From ae9e42479253a9cd30423476405377f2d7952137 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 17 Aug 2017 13:00:37 -0700 Subject: [PATCH 001/187] [SQL][MINOR][TEST] Set spark.unsafe.exceptionOnMemoryLeak to true ## What changes were proposed in this pull request? When running IntelliJ, we are unable to capture the exception of memory leak detection. > org.apache.spark.executor.Executor: Managed memory leak detected Explicitly setting `spark.unsafe.exceptionOnMemoryLeak` in SparkConf when building the SparkSession, instead of reading it from system properties. ## How was this patch tested? N/A Author: gatorsmile Closes #18967 from gatorsmile/setExceptionOnMemoryLeak. --- .../scala/org/apache/spark/sql/test/SharedSQLContext.scala | 4 +++- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 5ec76a4f0ec90..1f073d5f64c6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -31,7 +31,9 @@ import org.apache.spark.sql.{SparkSession, SQLContext} trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { protected def sparkConf = { - new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9e15baa4b2b74..10c9a2de6540a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -55,7 +55,8 @@ object TestHive "org.apache.spark.sql.hive.execution.PairSerDe") .set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath) // SPARK-8910 - .set("spark.ui.enabled", "false"))) + .set("spark.ui.enabled", "false") + .set("spark.unsafe.exceptionOnMemoryLeak", "true"))) case class TestHiveVersion(hiveClient: HiveClient) From 6aad02d03632df964363a144c96371e86f7b207e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 17 Aug 2017 22:47:14 +0200 Subject: [PATCH 002/187] [SPARK-18394][SQL] Make an AttributeSet.toSeq output order consistent ## What changes were proposed in this pull request? This pr sorted output attributes on their name and exprId in `AttributeSet.toSeq` to make the order consistent. If the order is different, spark possibly generates different code and then misses cache in `CodeGenerator`, e.g., `GenerateColumnAccessor` generates code depending on an input attribute order. ## How was this patch tested? Added tests in `AttributeSetSuite` and manually checked if the cache worked well in the given query of the JIRA. Author: Takeshi Yamamuro Closes #18959 from maropu/SPARK-18394. 
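A minimal sketch of the new guarantee, adapted from the added `AttributeSetSuite` test (the column names and exprId values are arbitrary): two sets built from attributes with identical names but different exprIds now produce the same output order from `toSeq`, because the result is sorted by `(name, exprId.id)`.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, ExprId}
import org.apache.spark.sql.types.IntegerType

val setA = AttributeSet(
  AttributeReference("c1", IntegerType)(exprId = ExprId(1098)) ::
  AttributeReference("c2", IntegerType)(exprId = ExprId(107)) :: Nil)
val setB = AttributeSet(
  AttributeReference("c1", IntegerType)(exprId = ExprId(392)) ::
  AttributeReference("c2", IntegerType)(exprId = ExprId(92)) :: Nil)

// Both sets sort by (name, exprId), so the name order is deterministic: Seq("c1", "c2").
assert(setA.toSeq.map(_.name) == setB.toSeq.map(_.name))
```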
--- .../catalyst/expressions/AttributeSet.scala | 7 +++- .../expressions/AttributeSetSuite.scala | 40 +++++++++++++++++++ .../sql/hive/execution/PruningSuite.scala | 7 +++- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index b77f93373e78d..7420b6b57d8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -121,7 +121,12 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all // sorts of things in its closure. - override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq + override def toSeq: Seq[Attribute] = { + // We need to keep a deterministic output order for `baseSet` because this affects a variable + // order in generated code (e.g., `GenerateColumnAccessor`). + // See SPARK-18394 for details. + baseSet.map(_.a).toSeq.sortBy { a => (a.name, a.exprId.id) } + } override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index 273f95f91ee50..b6e8b667a2400 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -78,4 +78,44 @@ class AttributeSetSuite extends SparkFunSuite { assert(aSet == aSet) assert(aSet == AttributeSet(aUpper :: Nil)) } + + test("SPARK-18394 keep a deterministic output order along with attribute names and exprIds") { + // Checks a simple case + val attrSeqA = { + val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(1098)) + val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(107)) + val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(838)) + val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil) + + val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(389)) + val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(89329)) + + val attrSetB = AttributeSet(attr4 :: attr5 :: Nil) + (attrSetA ++ attrSetB).toSeq.map(_.name) + } + + val attrSeqB = { + val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(392)) + val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(92)) + val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(87)) + val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil) + + val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(9023920)) + val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(522)) + val attrSetB = AttributeSet(attr4 :: attr5 :: Nil) + + (attrSetA ++ attrSetB).toSeq.map(_.name) + } + + assert(attrSeqA === attrSeqB) + + // Checks the same column names having different exprIds + val attr1 = AttributeReference("c", IntegerType)(exprId = ExprId(1098)) + val attr2 = AttributeReference("c", IntegerType)(exprId = ExprId(107)) + val attrSetA = AttributeSet(attr1 :: attr2 :: Nil) + val attr3 = AttributeReference("c", IntegerType)(exprId = ExprId(389)) + val attrSetB = AttributeSet(attr3 :: Nil) + + 
assert((attrSetA ++ attrSetB).toSeq === attr2 :: attr3 :: attr1 :: Nil) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index d535bef4cc787..cc592cf6ca629 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -162,7 +162,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { }.head assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") - assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") + + // Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method + // in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, in which + // the returned output columns are sorted by the names and expression ids. + assert(actualScannedColumns.sorted === expectedScannedColumns.sorted, + "Scanned columns mismatch") val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted From bfdc361ededb2ed4e323f075fdc40ed004b7f41d Mon Sep 17 00:00:00 2001 From: ArtRand Date: Thu, 17 Aug 2017 15:47:07 -0700 Subject: [PATCH 003/187] [SPARK-16742] Mesos Kerberos Support ## What changes were proposed in this pull request? Add Kerberos Support to Mesos. This includes kinit and --keytab support, but does not include delegation token renewal. ## How was this patch tested? Manually against a Secure DC/OS Apache HDFS cluster. Author: ArtRand Author: Michael Gummelt Closes #18519 from mgummelt/SPARK-16742-kerberos. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 29 +++++++++++--- .../org/apache/spark/deploy/SparkSubmit.scala | 38 ++++++++++++++----- .../HadoopDelegationTokenManager.scala | 8 ++++ .../CoarseGrainedExecutorBackend.scala | 7 ++++ .../cluster/CoarseGrainedClusterMessage.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 33 ++++++++++++++-- resource-managers/mesos/pom.xml | 11 ++++++ .../MesosCoarseGrainedSchedulerBackend.scala | 12 +++--- .../deploy/yarn/YarnSparkHadoopUtil.scala | 8 ---- 9 files changed, 117 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 2a92ef99b9f37..6d507d85331bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{File, IOException} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} import java.security.PrivilegedExceptionAction import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} @@ -147,14 +147,18 @@ class SparkHadoopUtil extends Logging { def isYarnMode(): Boolean = { false } - def getCurrentUserCredentials(): Credentials = { null } - - def addCurrentUserCredentials(creds: Credentials) {} - def addSecretKeyToUserCredentials(key: String, secret: String) {} def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } + def getCurrentUserCredentials(): Credentials = { + UserGroupInformation.getCurrentUser().getCredentials() + } + + def addCurrentUserCredentials(creds: Credentials): Unit = { + 
UserGroupInformation.getCurrentUser.addCredentials(creds) + } + def loginUserFromKeytab(principalName: String, keytabFilename: String): Unit = { if (!new File(keytabFilename).exists()) { throw new SparkException(s"Keytab file: ${keytabFilename} does not exist") @@ -425,6 +429,21 @@ class SparkHadoopUtil extends Logging { s"${if (status.isDirectory) "d" else "-"}$perm") false } + + def serialize(creds: Credentials): Array[Byte] = { + val byteStream = new ByteArrayOutputStream + val dataStream = new DataOutputStream(byteStream) + creds.writeTokenStorageToStream(dataStream) + byteStream.toByteArray + } + + def deserialize(tokenBytes: Array[Byte]): Credentials = { + val tokensBuf = new ByteArrayInputStream(tokenBytes) + + val creds = new Credentials() + creds.readTokenStorageStream(new DataInputStream(tokensBuf)) + creds + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 6d744a084a0fa..e7e8fbc25d0ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -34,6 +34,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.ivy.Ivy import org.apache.ivy.core.LogOptions import org.apache.ivy.core.module.descriptor._ @@ -49,6 +50,7 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBibl import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util._ @@ -556,19 +558,25 @@ object SparkSubmit extends CommandLineUtils { } // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL) { + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when principal is specified") - SparkHadoopUtil.get.loginUserFromKeytab(args.principal, args.keytab) - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. 
They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } } + if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { + setRMPrincipal(sysProps) + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" @@ -653,6 +661,18 @@ object SparkSubmit extends CommandLineUtils { (childArgs, childClasspath, sysProps, childMainClass) } + // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with + // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we + // must trick it into thinking we're YARN. + private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName + val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" + // scalastyle:off println + printStream.println(s"Setting ${key} to ${shortUserName}") + // scalastyle:off println + sysProps.put(key, shortUserName) + } + /** * Run the main method of the child class using the provided launch environment. * diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 01cbfe1ee6ae1..c317c4fe3d821 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -55,6 +55,14 @@ private[spark] class HadoopDelegationTokenManager( logDebug(s"Using the following delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") + /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ + def this(sparkConf: SparkConf, hadoopConf: Configuration) = { + this( + sparkConf, + hadoopConf, + hadoopConf => Set(FileSystem.get(hadoopConf).getHomeDirectory.getFileSystem(hadoopConf))) + } + private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), new HiveDelegationTokenProvider, diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a2f1aa22b0063..a5d60e90210f1 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -26,6 +26,8 @@ import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil @@ -219,6 +221,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.startCredentialUpdater(driverConf) } + cfg.hadoopDelegationCreds.foreach { hadoopCreds => + val creds = SparkHadoopUtil.get.deserialize(hadoopCreds) + SparkHadoopUtil.get.addCurrentUserCredentials(creds) + } + val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = 
false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 89a9ad6811e18..5d65731dfc30e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -32,7 +32,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SparkAppConfig( sparkProperties: Seq[(String, String)], - ioEncryptionKey: Option[Array[Byte]]) + ioEncryptionKey: Option[Array[Byte]], + hadoopDelegationCreds: Option[Array[Byte]]) extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a46824a0c6fad..a0ef209779309 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -24,7 +24,11 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future +import org.apache.hadoop.security.UserGroupInformation + import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -42,8 +46,8 @@ import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} */ private[spark] class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) - extends ExecutorAllocationClient with SchedulerBackend with Logging -{ + extends ExecutorAllocationClient with SchedulerBackend with Logging { + // Use an atomic variable to track total number of cores in the cluster for simplicity and speed protected val totalCoreCount = new AtomicInteger(0) // Total number of executors that are currently registered @@ -95,6 +99,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + // hadoop token manager used by some sub-classes (e.g. Mesos) + def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = None + + // Hadoop delegation tokens to be sent to the executors. 
+ val hadoopDelegationCreds: Option[Array[Byte]] = getHadoopDelegationCreds() + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -223,8 +233,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RetrieveSparkAppConfig => - val reply = SparkAppConfig(sparkProperties, - SparkEnv.get.securityManager.getIOEncryptionKey()) + val reply = SparkAppConfig( + sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey(), + hadoopDelegationCreds) context.reply(reply) } @@ -675,6 +687,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(KillExecutorsOnHost(host)) true } + + protected def getHadoopDelegationCreds(): Option[Array[Byte]] = { + if (UserGroupInformation.isSecurityEnabled && hadoopDelegationTokenManager.isDefined) { + hadoopDelegationTokenManager.map { manager => + val creds = UserGroupInformation.getCurrentUser.getCredentials + val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + manager.obtainDelegationTokens(hadoopConf, creds) + SparkHadoopUtil.get.serialize(creds) + } + } else { + None + } + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 20b53f2d8f987..2aa3228af79d6 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -74,6 +74,17 @@ test + + ${hive.group} + hive-exec + provided + + + ${hive.group} + hive-metastore + provided + + com.google.guava diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index e6b09572121d6..5ecd466194d8b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -22,15 +22,15 @@ import java.util.{Collections, List => JList} import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import java.util.concurrent.locks.ReentrantLock +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.mesos.SchedulerDriver - import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient @@ -55,8 +55,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( master: String, securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with org.apache.mesos.Scheduler - with MesosSchedulerUtils { + with org.apache.mesos.Scheduler with MesosSchedulerUtils { + + override def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = + Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration)) // Blacklist a slave after this many failures 
private val MAX_SLAVE_FAILURES = 2 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4fef4394bb3f0..3d9f99f57bed7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -74,14 +74,6 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } - override def getCurrentUserCredentials(): Credentials = { - UserGroupInformation.getCurrentUser().getCredentials() - } - - override def addCurrentUserCredentials(creds: Credentials) { - UserGroupInformation.getCurrentUser().addCredentials(creds) - } - override def addSecretKeyToUserCredentials(key: String, secret: String) { val creds = new Credentials() creds.addSecretKey(new Text(key), secret.getBytes(UTF_8)) From 7ab951885fd34aa8184b70a3a39b865a239e5052 Mon Sep 17 00:00:00 2001 From: Jen-Ming Chung Date: Thu, 17 Aug 2017 15:59:45 -0700 Subject: [PATCH 004/187] [SPARK-21677][SQL] json_tuple throws NullPointException when column is null as string type ## What changes were proposed in this pull request? ``` scala scala> Seq(("""{"Hyukjin": 224, "John": 1225}""")).toDS.selectExpr("json_tuple(value, trim(null))").show() ... java.lang.NullPointerException at ... ``` Currently the `null` field name will throw NullPointException. As a given field name null can't be matched with any field names in json, we just output null as its column value. This PR achieves it by returning a very unlikely column name `__NullFieldName` in evaluation of the field names. ## How was this patch tested? Added unit test. Author: Jen-Ming Chung Closes #18930 from jmchung/SPARK-21677. 
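With the fix in place, the same pattern degrades gracefully instead of throwing; roughly, in spark-shell (the expected values mirror the new `json-functions.sql` test case, which returns NULL for a null field name):

```scala
// A null field name matches no JSON field, so its column is simply NULL.
Seq("""{"a" : 1, "b" : 2}""").toDS
  .selectExpr("json_tuple(value, CAST(NULL AS STRING), 'b')")
  .show()  // expected: one row with (null, 2)
```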
--- .../expressions/jsonExpressions.scala | 8 ++--- .../expressions/JsonExpressionsSuite.scala | 10 ++++++ .../sql-tests/inputs/json-functions.sql | 6 ++++ .../sql-tests/results/json-functions.sql.out | 34 ++++++++++++++++++- 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 17b605438d587..c3757373a3cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -362,9 +362,9 @@ case class JsonTuple(children: Seq[Expression]) @transient private lazy val fieldExpressions: Seq[Expression] = children.tail // eagerly evaluate any foldable the field names - @transient private lazy val foldableFieldNames: IndexedSeq[String] = { + @transient private lazy val foldableFieldNames: IndexedSeq[Option[String]] = { fieldExpressions.map { - case expr if expr.foldable => expr.eval().asInstanceOf[UTF8String].toString + case expr if expr.foldable => Option(expr.eval()).map(_.asInstanceOf[UTF8String].toString) case _ => null }.toIndexedSeq } @@ -417,7 +417,7 @@ case class JsonTuple(children: Seq[Expression]) val fieldNames = if (constantFields == fieldExpressions.length) { // typically the user will provide the field names as foldable expressions // so we can use the cached copy - foldableFieldNames + foldableFieldNames.map(_.orNull) } else if (constantFields == 0) { // none are foldable so all field names need to be evaluated from the input row fieldExpressions.map(_.eval(input).asInstanceOf[UTF8String].toString) @@ -426,7 +426,7 @@ case class JsonTuple(children: Seq[Expression]) // prefer the cached copy when available foldableFieldNames.zip(fieldExpressions).map { case (null, expr) => expr.eval(input).asInstanceOf[UTF8String].toString - case (fieldName, _) => fieldName + case (fieldName, _) => fieldName.orNull } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f892e80204603..1cd2b4fc18a5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -363,6 +363,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(UTF8String.fromString("b\nc"))) } + test("SPARK-21677: json_tuple throws NullPointException when column is null as string type") { + checkJsonTuple( + JsonTuple(Literal("""{"f1": 1, "f2": 2}""") :: + NonFoldableLiteral("f1") :: + NonFoldableLiteral("cast(NULL AS STRING)") :: + NonFoldableLiteral("f2") :: + Nil), + InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("2"))) + } + val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID) test("from_json") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index b3cc2cea51d43..5a46fb4321f90 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -20,3 +20,9 @@ select from_json('{"a":1}', 'a InvalidType'); select from_json('{"a":1}', 'a INT', 
named_struct('mode', 'PERMISSIVE')); select from_json('{"a":1}', 'a INT', map('mode', 1)); select from_json(); +-- json_tuple +SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a'); +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a'); +SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; +-- Clean up +DROP VIEW IF EXISTS jsonTable; diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 22da20d9a9f4e..ae21d00116e9b 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 21 -- !query 0 @@ -178,3 +178,35 @@ struct<> -- !query 16 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_json; line 1 pos 7 + + +-- !query 17 +SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a') +-- !query 17 schema +struct +-- !query 17 output +NULL 2 NULL 1 + + +-- !query 18 +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a') +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable +-- !query 19 schema +struct +-- !query 19 output +2 NULL 1 + + +-- !query 20 +DROP VIEW IF EXISTS jsonTable +-- !query 20 schema +struct<> +-- !query 20 output + From 2caaed970e3e26ae59be5999516a737aff3e5c78 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 17 Aug 2017 16:33:39 -0700 Subject: [PATCH 005/187] [SPARK-21767][TEST][SQL] Add Decimal Test For Avro in VersionSuite ## What changes were proposed in this pull request? Decimal is a logical type of AVRO. We need to ensure the support of Hive's AVRO serde works well in Spark ## How was this patch tested? N/A Author: gatorsmile Closes #18977 from gatorsmile/addAvroTest. --- .../spark/sql/hive/client/VersionsSuite.scala | 68 ++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 82fbdd645ebe0..072e538b9ed54 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,7 +21,6 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -697,6 +696,73 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(versionSpark.table("t1").collect() === Array(Row(2))) } } + + test(s"$version: Decimal support of Avro Hive serde") { + val tableName = "tab1" + // TODO: add the other logical types. 
For details, see the link: + // https://avro.apache.org/docs/1.8.1/spec.html#Logical+Types + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + + Seq(true, false).foreach { isPartitioned => + withTable(tableName) { + val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else "" + // Creates the (non-)partitioned Avro table + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |$partitionClause + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType" + + if (isPartitioned) { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30, 'a'").collect()) + } + } else { + val insertStmt = s"INSERT OVERWRITE TABLE $tableName SELECT 1.3" + if (version == "0.12" || version == "0.13") { + val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage + assert(e.contains(errorMsg)) + } else { + versionSpark.sql(insertStmt) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + } + } + } + // TODO: add more tests. } } From 310454be3b0ce5ff6b6ef0070c5daadf6fb16927 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Thu, 17 Aug 2017 22:37:32 -0700 Subject: [PATCH 006/187] [SPARK-21739][SQL] Cast expression should initialize timezoneId when it is called statically to convert something into TimestampType ## What changes were proposed in this pull request? https://issues.apache.org/jira/projects/SPARK/issues/SPARK-21739 This issue is caused by introducing TimeZoneAwareExpression. When the **Cast** expression converts something into TimestampType, it should be resolved with setting `timezoneId`. In general, it is resolved in LogicalPlan phase. However, there are still some places that use Cast expression statically to convert datatypes without setting `timezoneId`. In such cases, `NoSuchElementException: None.get` will be thrown for TimestampType. This PR is proposed to fix the issue. We have checked the whole project and found two such usages(i.e., in`TableReader` and `HiveTableScanExec`). ## How was this patch tested? unit test Author: donnyzone Closes #18960 from DonnyZone/spark-21739. 
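For context, the failing scenario (taken from the new `QueryPartitionSuite` test) can be reproduced in a Hive-enabled spark-shell roughly as follows; before this change, reads like these hit `NoSuchElementException: None.get` while evaluating the partition-key `Cast`:

```scala
spark.sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
spark.sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
  "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

// Exercises the Cast in TableReader.
spark.sql("SELECT * FROM table_with_timestamp_partition").show()

// Exercises the Cast in HiveTableScanExec.
spark.sql("SELECT value FROM table_with_timestamp_partition " +
  "WHERE ts = '2010-01-01 00:00:00.000'").show()
```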
--- .../org/apache/spark/sql/hive/TableReader.scala | 8 ++++++-- .../sql/hive/execution/HiveTableScanExec.scala | 8 ++++++-- .../spark/sql/hive/QueryPartitionSuite.scala | 17 +++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index f238b9a4f7f6f..cc8907a0bbc93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -65,7 +67,7 @@ class HadoopTableReader( @transient private val tableDesc: TableDesc, @transient private val sparkSession: SparkSession, hadoopConf: Configuration) - extends TableReader with Logging { + extends TableReader with CastSupport with Logging { // Hadoop honors "mapreduce.job.maps" as hint, // but will ignore when mapreduce.jobtracker.address is "local". @@ -86,6 +88,8 @@ class HadoopTableReader( private val _broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + override def conf: SQLConf = sparkSession.sessionState.conf + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, @@ -227,7 +231,7 @@ class HadoopTableReader( def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) + row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 896f24f2e223d..48d0b4a63e54a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -37,6 +38,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -53,11 +55,13 @@ case class HiveTableScanExec( relation: HiveTableRelation, partitionPruningPred: Seq[Expression])( @transient private val sparkSession: SparkSession) 
- extends LeafExecNode { + extends LeafExecNode with CastSupport { require(partitionPruningPred.isEmpty || relation.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + override def conf: SQLConf = sparkSession.sessionState.conf + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -104,7 +108,7 @@ case class HiveTableScanExec( hadoopConf) private def castFromString(value: String, dataType: DataType) = { - Cast(Literal(value), dataType).eval(null) + cast(Literal(value), dataType).eval(null) } private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 43b6bf5feeb60..b2dc401ce1efc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem @@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl sql("DROP TABLE IF EXISTS createAndInsertTest") } } + + test("SPARK-21739: Cast expression should initialize timezoneId") { + withTable("table_with_timestamp_partition") { + sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") + sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " + + "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)") + + // test for Cast expression in TableReader + checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"), + Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000")))) + + // test for Cast expression in HiveTableScanExec + checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " + + "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1)) + } + } } From 07a2b8738ed8e6c136545d03f91a865de05e41a0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Aug 2017 23:58:20 +0900 Subject: [PATCH 007/187] [SPARK-21778][SQL] Simpler Dataset.sample API in Scala / Java ## What changes were proposed in this pull request? Dataset.sample requires a boolean flag withReplacement as the first argument. However, most of the time users simply want to sample some records without replacement. This ticket introduces a new sample function that simply takes in the fraction and seed. ## How was this patch tested? Tested manually. Not sure yet if we should add a test case for just this wrapper ... Author: Reynold Xin Closes #18988 from rxin/SPARK-21778. --- .../scala/org/apache/spark/sql/Dataset.scala | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a9887eb95279f..615686ccbe2b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1848,11 +1848,43 @@ class Dataset[T] private[sql]( Except(logicalPlan, other.logicalPlan) } + /** + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), + * using a user-supplied seed. + * + * @param fraction Fraction of rows to generate, range [0.0, 1.0]. + * @param seed Seed for sampling. 
+ * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. + * + * @group typedrel + * @since 2.3.0 + */ + def sample(fraction: Double, seed: Long): Dataset[T] = { + sample(withReplacement = false, fraction = fraction, seed = seed) + } + + /** + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement). + * + * @param fraction Fraction of rows to generate, range [0.0, 1.0]. + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[Dataset]]. + * + * @group typedrel + * @since 2.3.0 + */ + def sample(fraction: Double): Dataset[T] = { + sample(withReplacement = false, fraction = fraction) + } + /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed. * * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate. + * @param fraction Fraction of rows to generate, range [0.0, 1.0]. * @param seed Seed for sampling. * * @note This is NOT guaranteed to provide exactly the fraction of the count @@ -1871,7 +1903,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate. + * @param fraction Fraction of rows to generate, range [0.0, 1.0]. * * @note This is NOT guaranteed to provide exactly the fraction of the total count * of the given [[Dataset]]. From 23ea8980809497d0372084adf5936602655e1685 Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Fri, 18 Aug 2017 09:54:39 -0700 Subject: [PATCH 008/187] [SPARK-21213][SQL] Support collecting partition-level statistics: rowCount and sizeInBytes ## What changes were proposed in this pull request? Added support for ANALYZE TABLE [db_name].tablename PARTITION (partcol1[=val1], partcol2[=val2], ...) COMPUTE STATISTICS [NOSCAN] SQL command to calculate total number of rows and size in bytes for a subset of partitions. Calculated statistics are stored in Hive Metastore as user-defined properties attached to partition objects. Property names are the same as the ones used to store table-level statistics: spark.sql.statistics.totalSize and spark.sql.statistics.numRows. When partition specification contains all partition columns with values, the command collects statistics for a single partition that matches the specification. When some partition columns are missing or listed without their values, the command collects statistics for all partitions which match a subset of partition column values specified. For example, table t has 4 partitions with the following specs: * Partition1: (ds='2008-04-08', hr=11) * Partition2: (ds='2008-04-08', hr=12) * Partition3: (ds='2008-04-09', hr=11) * Partition4: (ds='2008-04-09', hr=12) 'ANALYZE TABLE t PARTITION (ds='2008-04-09', hr=11)' command will collect statistics only for partition 3. 'ANALYZE TABLE t PARTITION (ds='2008-04-09')' command will collect statistics for partitions 3 and 4. 'ANALYZE TABLE t PARTITION (ds, hr)' command will collect statistics for all four partitions. When the optional parameter NOSCAN is specified, the command doesn't count number of rows and only gathers size in bytes. The statistics gathered by ANALYZE TABLE command can be fetched using DESC EXTENDED [db_name.]tablename PARTITION command. ## How was this patch tested? Added tests. Author: Masha Basmanova Closes #18421 from mbasmanova/mbasmanova-analyze-partition. 
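As a quick usage sketch against the four-partition table `t` described above (Hive-enabled session; the collected statistics are stored in the metastore and surface under `Partition Statistics` in DESC EXTENDED):

```scala
// Single partition: row count and size in bytes.
spark.sql("ANALYZE TABLE t PARTITION (ds='2008-04-09', hr=11) COMPUTE STATISTICS")

// All partitions with ds='2008-04-09'; NOSCAN gathers size only and skips the row count.
spark.sql("ANALYZE TABLE t PARTITION (ds='2008-04-09') COMPUTE STATISTICS NOSCAN")

// All four partitions.
spark.sql("ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS")

// Inspect the collected statistics for one partition.
spark.sql("DESC EXTENDED t PARTITION (ds='2008-04-09', hr=11)").show(100, truncate = false)
```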
--- .../sql/catalyst/catalog/interface.scala | 7 +- .../spark/sql/execution/SparkSqlParser.scala | 36 ++- .../command/AnalyzePartitionCommand.scala | 149 ++++++++++ .../command/AnalyzeTableCommand.scala | 28 +- .../sql/execution/command/CommandUtils.scala | 27 +- .../inputs/describe-part-after-analyze.sql | 34 +++ .../describe-part-after-analyze.sql.out | 244 +++++++++++++++++ .../sql/execution/SparkSqlParserSuite.scala | 33 ++- .../spark/sql/hive/HiveExternalCatalog.scala | 169 ++++++++---- .../sql/hive/client/HiveClientImpl.scala | 2 + .../spark/sql/hive/StatisticsSuite.scala | 254 ++++++++++++++++++ 11 files changed, 888 insertions(+), 95 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 5a8c4e7610fff..1965144e81197 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -91,12 +91,14 @@ object CatalogStorageFormat { * * @param spec partition spec values indexed by column name * @param storage storage format of the partition - * @param parameters some parameters for the partition, for example, stats. + * @param parameters some parameters for the partition + * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, - parameters: Map[String, String] = Map.empty) { + parameters: Map[String, String] = Map.empty, + stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { val map = new mutable.LinkedHashMap[String, String]() @@ -106,6 +108,7 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d4414b6f78ca2..8379e740a0717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -90,30 +90,40 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command. - * Example SQL for analyzing table : + * Create an [[AnalyzeTableCommand]] command, or an [[AnalyzePartitionCommand]] + * or an [[AnalyzeColumnCommand]] command. 
+ * Example SQL for analyzing a table or a set of partitions : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN]; + * ANALYZE TABLE [db_name.]tablename [PARTITION (partcol1[=val1], partcol2[=val2], ...)] + * COMPUTE STATISTICS [NOSCAN]; * }}} + * * Example SQL for analyzing columns : * {{{ - * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2; + * ANALYZE TABLE [db_name.]tablename COMPUTE STATISTICS FOR COLUMNS column1, column2; * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec != null) { - logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") + if (ctx.identifier != null && + ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { + throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } - if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { - throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) + + val table = visitTableIdentifier(ctx.tableIdentifier) + if (ctx.identifierSeq() == null) { + if (ctx.partitionSpec != null) { + AnalyzePartitionCommand(table, visitPartitionSpec(ctx.partitionSpec), + noscan = ctx.identifier != null) + } else { + AnalyzeTableCommand(table, noscan = ctx.identifier != null) } - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) - } else if (ctx.identifierSeq() == null) { - AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false) } else { + if (ctx.partitionSpec != null) { + logWarning("Partition specification is ignored when collecting column statistics: " + + ctx.partitionSpec.getText) + } AnalyzeColumnCommand( - visitTableIdentifier(ctx.tableIdentifier), + table, visitIdentifierSeq(ctx.identifierSeq())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala new file mode 100644 index 0000000000000..5b54b2270b5ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} +import org.apache.spark.sql.execution.datasources.PartitioningUtils + +/** + * Analyzes a given set of partitions to generate per-partition statistics, which will be used in + * query optimizations. + * + * When `partitionSpec` is empty, statistics for all partitions are collected and stored in + * Metastore. + * + * When `partitionSpec` mentions only some of the partition columns, all partitions with + * matching values for specified columns are processed. + * + * If `partitionSpec` mentions unknown partition column, an `AnalysisException` is raised. + * + * By default, total number of rows and total size in bytes are calculated. When `noscan` + * is `true`, only total size in bytes is computed. + */ +case class AnalyzePartitionCommand( + tableIdent: TableIdentifier, + partitionSpec: Map[String, Option[String]], + noscan: Boolean = true) extends RunnableCommand { + + private def getPartitionSpec(table: CatalogTable): Option[TablePartitionSpec] = { + val normalizedPartitionSpec = + PartitioningUtils.normalizePartitionSpec(partitionSpec, table.partitionColumnNames, + table.identifier.quotedString, conf.resolver) + + // Report an error if partition columns in partition specification do not form + // a prefix of the list of partition columns defined in the table schema + val isNotSpecified = + table.partitionColumnNames.map(normalizedPartitionSpec.getOrElse(_, None).isEmpty) + if (isNotSpecified.init.zip(isNotSpecified.tail).contains((true, false))) { + val tableId = table.identifier + val schemaColumns = table.partitionColumnNames.mkString(",") + val specColumns = normalizedPartitionSpec.keys.mkString(",") + throw new AnalysisException("The list of partition columns with values " + + s"in partition specification for table '${tableId.table}' " + + s"in database '${tableId.database.get}' is not a prefix of the list of " + + "partition columns defined in the table schema. 
" + + s"Expected a prefix of [${schemaColumns}], but got [${specColumns}].") + } + + val filteredSpec = normalizedPartitionSpec.filter(_._2.isDefined).mapValues(_.get) + if (filteredSpec.isEmpty) { + None + } else { + Some(filteredSpec) + } + } + + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) + val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) + if (tableMeta.tableType == CatalogTableType.VIEW) { + throw new AnalysisException("ANALYZE TABLE is not supported on views.") + } + + val partitionValueSpec = getPartitionSpec(tableMeta) + + val partitions = sessionState.catalog.listPartitions(tableMeta.identifier, partitionValueSpec) + + if (partitions.isEmpty) { + if (partitionValueSpec.isDefined) { + throw new NoSuchPartitionException(db, tableIdent.table, partitionValueSpec.get) + } else { + // the user requested to analyze all partitions for a table which has no partitions + // return normally, since there is nothing to do + return Seq.empty[Row] + } + } + + // Compute statistics for individual partitions + val rowCounts: Map[TablePartitionSpec, BigInt] = + if (noscan) { + Map.empty + } else { + calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec) + } + + // Update the metastore if newly computed statistics are different from those + // recorded in the metastore. + val newPartitions = partitions.flatMap { p => + val newTotalSize = CommandUtils.calculateLocationSize( + sessionState, tableMeta.identifier, p.storage.locationUri) + val newRowCount = rowCounts.get(p.spec) + val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) + newStats.map(_ => p.copy(stats = newStats)) + } + + if (newPartitions.nonEmpty) { + sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions) + } + + Seq.empty[Row] + } + + private def calculateRowCountsPerPartition( + sparkSession: SparkSession, + tableMeta: CatalogTable, + partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { + val filter = if (partitionValueSpec.isDefined) { + val filters = partitionValueSpec.get.map { + case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) + } + filters.reduce(And) + } else { + Literal.TrueLiteral + } + + val tableDf = sparkSession.table(tableMeta.identifier) + val partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) + + val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() + + df.collect().map { r => + val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString) + val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap + val count = BigInt(r.getLong(partitionColumns.size)) + (spec, count) + }.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index cba147c35dd99..04715bd314d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import 
org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType /** @@ -37,31 +37,15 @@ case class AnalyzeTableCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } + + // Compute stats for the whole table val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val newRowCount = + if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) - val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) - var newStats: Option[CatalogStatistics] = None - if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { - newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) - } - // We only set rowCount when noscan is false, because otherwise: - // 1. when total size is not changed, we don't need to alter the table; - // 2. when total size is changed, `oldRowCount` becomes invalid. - // This is to make sure that we only record the right statistics. - if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).count() - if (newRowCount >= 0 && newRowCount != oldRowCount) { - newStats = if (newStats.isDefined) { - newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) - } else { - Some(CatalogStatistics( - sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount)))) - } - } - } // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. + val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) if (newStats.isDefined) { sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) // Refresh the cached data source table in the catalog. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index de45be85220e9..b22958d59336c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.internal.SessionState @@ -112,4 +112,29 @@ object CommandUtils extends Logging { size } + def compareAndGetNewStats( + oldStats: Option[CatalogStatistics], + newTotalSize: BigInt, + newRowCount: Option[BigInt]): Option[CatalogStatistics] = { + val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) + val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + var newStats: Option[CatalogStatistics] = None + if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { + newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) + } + // We only set rowCount when noscan is false, because otherwise: + // 1. when total size is not changed, we don't need to alter the table; + // 2. when total size is changed, `oldRowCount` becomes invalid. 
+ // This is to make sure that we only record the right statistics. + if (newRowCount.isDefined) { + if (newRowCount.get >= 0 && newRowCount.get != oldRowCount) { + newStats = if (newStats.isDefined) { + newStats.map(_.copy(rowCount = newRowCount)) + } else { + Some(CatalogStatistics(sizeInBytes = oldTotalSize, rowCount = newRowCount)) + } + } + } + newStats + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql new file mode 100644 index 0000000000000..f4239da906276 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql @@ -0,0 +1,34 @@ +CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet + PARTITIONED BY (ds, hr); + +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10) +VALUES ('k1', 100), ('k2', 200), ('k3', 300); + +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11) +VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401); + +INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5) +VALUES ('k1', 102), ('k2', 202); + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); + +-- Collect stats for a single partition +ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) COMPUTE STATISTICS; + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); + +-- Collect stats for 2 partitions +ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS; + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11); + +-- Collect stats for all partitions +ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS; + +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10); +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11); +DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5); + +-- DROP TEST TABLES/VIEWS +DROP TABLE t; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out new file mode 100644 index 0000000000000..51dac111029e8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -0,0 +1,244 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 15 + + +-- !query 0 +CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet + PARTITIONED BY (ds, hr) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10) +VALUES ('k1', 100), ('k2', 200), ('k3', 300) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11) +VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5) +VALUES ('k1', 102), ('k2', 202) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 4 schema +struct +-- !query 4 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 5 +ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) 
COMPUTE STATISTICS +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 6 schema +struct +-- !query 6 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 7 +ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 8 schema +struct +-- !query 8 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 9 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11) +-- !query 9 schema +struct +-- !query 9 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=11] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Partition Statistics 1080 bytes, 4 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 10 +ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10) +-- !query 11 schema +struct +-- !query 11 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=10] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Partition Statistics 1067 bytes, 3 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 12 +DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11) +-- !query 12 schema +struct +-- !query 12 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-08-01, hr=11] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Partition Statistics 1080 bytes, 4 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 13 +DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5) +-- !query 13 schema +struct +-- !query 13 output +key string +value string +ds string +hr int +# Partition Information +# col_name data_type comment +ds string +hr int + +# Detailed Partition Information +Database default +Table t +Partition Values [ds=2017-09-01, hr=5] +Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 +Partition 
Statistics 1054 bytes, 2 rows + +# Storage Information +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 14 +DROP TABLE t +-- !query 14 schema +struct<> +-- !query 14 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index d238c76fbeeff..fa7a866f4d551 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -259,17 +259,33 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual("analyze table t compute statistics noscan", AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) assertEqual("analyze table t partition (a) compute statistics nOscAn", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), Map("a" -> None), noscan = true)) - // Partitions specified - we currently parse them but don't do anything with it + // Partitions specified assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", - AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11")))) assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11")))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09') COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09")))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None))) + assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None))) + assertEqual("ANALYZE TABLE t PARTITION(ds, hr=11) COMPUTE STATISTICS noscan", + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> None, "hr" -> Some("11")))) assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS", - AnalyzeTableCommand(TableIdentifier("t"), noscan = false)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = false, + partitionSpec = Map("ds" -> None, "hr" -> None))) assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan", - AnalyzeTableCommand(TableIdentifier("t"), noscan = true)) + AnalyzePartitionCommand(TableIdentifier("t"), noscan = true, + partitionSpec = Map("ds" -> None, "hr" -> None))) intercept("analyze table t compute statistics xxxx", "Expected `NOSCAN` instead of `xxxx`") @@ -282,6 +298,11 @@ class SparkSqlParserSuite extends AnalysisTest { assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + + // Partition specified - should be ignored + assertEqual("ANALYZE TABLE t PARTITION(ds='2017-06-10') " + + "COMPUTE STATISTICS FOR COLUMNS key, value", + AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) } test("query organization") 
{ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index e9d48f95aa905..547447b31f0a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -639,26 +639,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) - // convert table statistics to properties so that we can persist them through hive client - val statsProperties = new mutable.HashMap[String, String]() - if (stats.isDefined) { - statsProperties += STATISTICS_TOTAL_SIZE -> stats.get.sizeInBytes.toString() - if (stats.get.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.get.rowCount.get.toString() - } - - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema + // For datasource tables and hive serde tables created by spark 2.1 or higher, + // the data schema is stored in the table properties. + val schema = restoreTableMetadata(rawTable).schema - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap - stats.get.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) - } + // convert table statistics to properties so that we can persist them through hive client + var statsProperties = + if (stats.isDefined) { + statsToProperties(stats.get, schema) + } else { + new mutable.HashMap[String, String]() } - } val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX)) val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties) @@ -704,36 +695,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val version: String = table.properties.getOrElse(CREATED_SPARK_VERSION, "2.2 or prior") // Restore Spark's statistics from information in Metastore. - val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) - - // Currently we have two sources of statistics: one from Hive and the other from Spark. - // In our design, if Spark's statistics is available, we respect it over Hive's statistics. - if (statsProps.nonEmpty) { - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - table.schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. 
- val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - - ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach { - colStat => colStats += field.name -> colStat - } - } - } - - table = table.copy( - stats = Some(CatalogStatistics( - sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)), - rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)), - colStats = colStats.toMap))) + val restoredStats = + statsFromProperties(table.properties, table.identifier.table, table.schema) + if (restoredStats.isDefined) { + table = table.copy(stats = restoredStats) } // Get the original table properties as defined by the user. @@ -1037,17 +1002,92 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } + private def statsToProperties( + stats: CatalogStatistics, + schema: StructType): Map[String, String] = { + + var statsProperties: Map[String, String] = + Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + if (stats.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + + val colNameTypeMap: Map[String, DataType] = + schema.fields.map(f => (f.name, f.dataType)).toMap + stats.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } + } + + statsProperties + } + + private def statsFromProperties( + properties: Map[String, String], + table: String, + schema: StructType): Option[CatalogStatistics] = { + + val statsProps = properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + if (statsProps.isEmpty) { + None + } else { + + val colStats = new mutable.HashMap[String, ColumnStat] + + // For each column, recover its column stats. Note that this is currently a O(n^2) operation, + // but given the number of columns it usually not enormous, this is probably OK as a start. + // If we want to map this a linear operation, we'd need a stronger contract between the + // naming convention used for serialization. + schema.foreach { field => + if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { + // If "version" field is defined, then the column stat is defined. + val keyPrefix = columnStatKeyPropName(field.name, "") + val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => + (k.drop(keyPrefix.length), v) + } + + ColumnStat.fromMap(table, field, colStatMap).foreach { + colStat => colStats += field.name -> colStat + } + } + } + + Some(CatalogStatistics( + sizeInBytes = BigInt(statsProps(STATISTICS_TOTAL_SIZE)), + rowCount = statsProps.get(STATISTICS_NUM_ROWS).map(BigInt(_)), + colStats = colStats.toMap)) + } + } + override def alterPartitions( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withClient { val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) + + val rawTable = getRawTable(db, table) + + // For datasource tables and hive serde tables created by spark 2.1 or higher, + // the data schema is stored in the table properties. 
+ val schema = restoreTableMetadata(rawTable).schema + + // convert partition statistics to properties so that we can persist them through hive api + val withStatsProps = lowerCasedParts.map(p => { + if (p.stats.isDefined) { + val statsProperties = statsToProperties(p.stats.get, schema) + p.copy(parameters = p.parameters ++ statsProperties) + } else { + p + } + }) + // Note: Before altering table partitions in Hive, you *must* set the current database // to the one that contains the table of interest. Otherwise you will end up with the // most helpful error message ever: "Unable to alter partition. alter is not possible." // See HIVE-2742 for more detail. client.setCurrentDatabase(db) - client.alterPartitions(db, table, lowerCasedParts) + client.alterPartitions(db, table, withStatsProps) } override def getPartition( @@ -1055,7 +1095,34 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, spec: TablePartitionSpec): CatalogTablePartition = withClient { val part = client.getPartition(db, table, lowerCasePartitionSpec(spec)) - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + restorePartitionMetadata(part, getTable(db, table)) + } + + /** + * Restores partition metadata from the partition properties. + * + * Reads partition-level statistics from partition properties, puts these + * into [[CatalogTablePartition#stats]] and removes these special entries + * from the partition properties. + */ + private def restorePartitionMetadata( + partition: CatalogTablePartition, + table: CatalogTable): CatalogTablePartition = { + val restoredSpec = restorePartitionSpec(partition.spec, table.partitionColumnNames) + + // Restore Spark's statistics from information in Metastore. + // Note: partition-level statistics were introduced in 2.3. 
+ val restoredStats = + statsFromProperties(partition.parameters, table.identifier.table, table.schema) + if (restoredStats.isDefined) { + partition.copy( + spec = restoredSpec, + stats = restoredStats, + parameters = partition.parameters.filterNot { + case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) + } else { + partition.copy(spec = restoredSpec) + } } /** @@ -1066,7 +1133,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, spec: TablePartitionSpec): Option[CatalogTablePartition] = withClient { client.getPartitionOption(db, table, lowerCasePartitionSpec(spec)).map { part => - part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames)) + restorePartitionMetadata(part, getTable(db, table)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 5e5c0a2a5078c..995280e0e9416 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,6 +21,7 @@ import java.io.{File, PrintStream} import java.util.Locale import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration @@ -960,6 +961,7 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 71cf79c473b46..dc6140756d519 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.StringUtils @@ -256,6 +257,259 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("analyze single partition") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String): CatalogStatistics = { + val partition = + spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds)) + partition.stats.get + } + + def createPartition(ds: String, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + createPartition("2010-01-01", "SELECT '1', 'A' from src") + createPartition("2010-01-02", "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + createPartition("2010-01-03", "SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN") + + assert(queryStats("2010-01-01").rowCount === None) + 
assert(queryStats("2010-01-01").sizeInBytes === 2000) + + assert(queryStats("2010-01-02").rowCount === None) + assert(queryStats("2010-01-02").sizeInBytes === 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS") + + assert(queryStats("2010-01-01").rowCount.get === 500) + assert(queryStats("2010-01-01").sizeInBytes === 2000) + + assert(queryStats("2010-01-02").rowCount.get === 2*500) + assert(queryStats("2010-01-02").sizeInBytes === 2*2000) + } + } + + test("analyze a set of partitions") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String, hr: String): Option[CatalogStatistics] = { + val tableId = TableIdentifier(tableName) + val partition = + spark.sessionState.catalog.getPartition(tableId, Map("ds" -> ds, "hr" -> hr)) + partition.stats + } + + def assertPartitionStats( + ds: String, + hr: String, + rowCount: Option[BigInt], + sizeInBytes: BigInt): Unit = { + val stats = queryStats(ds, hr).get + assert(stats.rowCount === rowCount) + assert(stats.sizeInBytes === sizeInBytes) + } + + def createPartition(ds: String, hr: Int, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)") + + createPartition("2010-01-01", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-01", 11, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 11, + "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assert(queryStats("2010-01-02", "10") === None) + assert(queryStats("2010-01-02", "11") === None) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000) + } + } + + test("analyze all partitions") { + val tableName = "analyzeTable_part" + + def assertPartitionStats( + ds: String, + hr: String, + rowCount: Option[BigInt], + sizeInBytes: BigInt): Unit = { + val stats = 
spark.sessionState.catalog.getPartition(TableIdentifier(tableName), + Map("ds" -> ds, "hr" -> hr)).stats.get + assert(stats.rowCount === rowCount) + assert(stats.sizeInBytes === sizeInBytes) + } + + def createPartition(ds: String, hr: Int, query: String): Unit = { + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query") + } + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)") + + createPartition("2010-01-01", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-01", 11, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 10, "SELECT '1', 'A' from src") + createPartition("2010-01-02", 11, + "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src") + + sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS NOSCAN") + + assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000) + + sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS") + + assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000) + assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000) + } + } + + test("analyze partitions for an empty table") { + val tableName = "analyzeTable_part" + + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + // make sure there is no exception + sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS NOSCAN") + + // make sure there is no exception + sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS") + } + } + + test("analyze partitions case sensitivity") { + val tableName = "analyzeTable_part" + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS") + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val message = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS") + }.getMessage + assert(message.contains( + s"DS is not a valid partition column in table `default`.`${tableName.toLowerCase}`")) + } + } + } + + test("analyze partial partition specifications") { + + val tableName = "analyzeTable_part" + + def assertAnalysisException(partitionSpec: String): Unit = { + val message = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName $partitionSpec COMPUTE STATISTICS") + }.getMessage + assert(message.contains("The list of partition columns with values " + + s"in partition specification for table '${tableName.toLowerCase}' in database 'default' " + + "is not a prefix of the list of partition columns defined in the table schema")) + } + + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (a STRING, b INT, c STRING) + """.stripMargin) + + sql(s"INSERT INTO TABLE $tableName PARTITION (a='a1', 
b=10, c='c1') SELECT * FROM src") + + sql(s"ANALYZE TABLE $tableName PARTITION (a='a1') COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (a='a1', b=10) COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (A='a1', b=10) COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (b=10, a='a1') COMPUTE STATISTICS") + sql(s"ANALYZE TABLE $tableName PARTITION (b=10, A='a1') COMPUTE STATISTICS") + + assertAnalysisException("PARTITION (b=10)") + assertAnalysisException("PARTITION (a, b=10)") + assertAnalysisException("PARTITION (b=10, c='c1')") + assertAnalysisException("PARTITION (a, b=10, c='c1')") + assertAnalysisException("PARTITION (c='c1')") + assertAnalysisException("PARTITION (a, b, c='c1')") + assertAnalysisException("PARTITION (a='a1', c='c1')") + assertAnalysisException("PARTITION (a='a1', b, c='c1')") + } + } + + test("analyze non-existent partition") { + + def assertAnalysisException(analyzeCommand: String, errorMessage: String): Unit = { + val message = intercept[AnalysisException] { + sql(analyzeCommand) + }.getMessage + assert(message.contains(errorMessage)) + } + + val tableName = "analyzeTable_part" + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src") + + assertAnalysisException( + s"ANALYZE TABLE $tableName PARTITION (hour=20) COMPUTE STATISTICS", + s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + ) + + assertAnalysisException( + s"ANALYZE TABLE $tableName PARTITION (hour) COMPUTE STATISTICS", + s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + ) + + intercept[NoSuchPartitionException] { + sql(s"ANALYZE TABLE $tableName PARTITION (ds='2011-02-30') COMPUTE STATISTICS") + } + } + } + test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { From 7880909c45916ab76dccac308a9b2c5225a00e89 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 18 Aug 2017 11:19:22 -0700 Subject: [PATCH 009/187] [SPARK-21743][SQL][FOLLOW-UP] top-most limit should not cause memory leak ## What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/18955 , to fix a bug that we break whole stage codegen for `Limit`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18993 from cloud-fan/bug. --- .../sql/catalyst/optimizer/Optimizer.scala | 5 ++- .../spark/sql/execution/SparkStrategies.scala | 37 ++++++++----------- .../apache/spark/sql/execution/limit.scala | 8 ---- 3 files changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a51b385399d88..e2d7164d93ac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1171,7 +1171,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to * another LocalRelation. * - * This is relatively simple as it currently handles only a single case: Project. + * This is relatively simple as it currently handles only 2 single case: Project and Limit. 
*/ object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1180,6 +1180,9 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { val projection = new InterpretedProjection(projectList, output) projection.initialize(0) LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + + case Limit(IntegerLiteral(limit), LocalRelation(output, data)) => + LocalRelation(output, data.take(limit)) } private def hasUnevaluableExpr(expr: Expression): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2e8ce4541865d..c115cb6e80e91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,29 +63,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.ReturnAnswer(rootPlan) => rootPlan match { - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case logical.Limit( - IntegerLiteral(limit), - logical.Project(projectList, logical.Sort(order, true, child))) => - execution.TakeOrderedAndProjectExec( - limit, order, projectList, planLater(child)) :: Nil - case logical.Limit(IntegerLiteral(limit), child) => - // Normally wrapping child with `LocalLimitExec` here is a no-op, because - // `CollectLimitExec.executeCollect` will call `LocalLimitExec.executeTake`, which - // calls `child.executeTake`. If child supports whole stage codegen, adding this - // `LocalLimitExec` can stop the processing of whole stage codegen and trigger the - // resource releasing work, after we consume `limit` rows. - execution.CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), child) => + // With whole stage codegen, Spark releases resources only when all the output data of the + // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little + // data from child plan and finishes the query without releasing resources. Here we wrap + // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and + // trigger the resource releasing work, after we consume `limit` rows. 
+ CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case logical.Limit( - IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => - execution.TakeOrderedAndProjectExec( - limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 7cef5569717a3..73a0f8735ed45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -54,14 +54,6 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output - // Do not enable whole stage codegen for a single limit. - override def supportCodegen: Boolean = child match { - case plan: CodegenSupport => plan.supportCodegen - case _ => false - } - - override def executeTake(n: Int): Array[InternalRow] = child.executeTake(math.min(n, limit)) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } From a2db5c5761b0c72babe48b79859d3b208ee8e9f6 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Fri, 18 Aug 2017 13:43:42 -0700 Subject: [PATCH 010/187] [MINOR][TYPO] Fix typos: runnning and Excecutors ## What changes were proposed in this pull request? Fix typos ## How was this patch tested? Existing tests Author: Andrew Ash Closes #18996 from ash211/patch-2. --- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index f73e7dc0bb567..7052fb347106b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -551,8 +551,8 @@ private[yarn] class YarnAllocator( updateInternalState() } } else { - logInfo(("Skip launching executorRunnable as runnning Excecutors count: %d " + - "reached target Executors count: %d.").format( + logInfo(("Skip launching executorRunnable as running executors count: %d " + + "reached target executors count: %d.").format( numExecutorsRunning.get, targetNumExecutors)) } } From 10be01848ef28004a287940a4e8d8a044e14b257 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 18 Aug 2017 18:10:54 -0700 Subject: [PATCH 011/187] [SPARK-21566][SQL][PYTHON] Python method for summary ## What changes were proposed in this pull request? Adds the recently added `summary` method to the python dataframe interface. ## How was this patch tested? Additional inline doctests. Author: Andrew Ray Closes #18762 from aray/summary-py. 
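For readers coming from the Scala API: the Python method added below is a thin py4j wrapper that forwards to the JVM-side `Dataset.summary(statistics: String*)` (the "recently added `summary` method" mentioned above) via `self._jdf.summary(...)`. The sketch below is illustrative only — the local session and toy data are assumptions, not part of this patch — and shows the underlying Scala call the wrapper delegates to:

```scala
import org.apache.spark.sql.SparkSession

// Assumed local session and toy data, mirroring the doctest DataFrame used in the diff below.
val spark = SparkSession.builder().master("local[1]").appName("summary-sketch").getOrCreate()
import spark.implicits._

val df = Seq((2, "Alice"), (5, "Bob")).toDF("age", "name")

// No arguments: count, mean, stddev, min, approximate quartiles (25%, 50%, 75%), max.
df.summary().show()

// Specific statistics, matching df.summary("count", "min", "25%", "75%", "max") in Python.
df.summary("count", "min", "25%", "75%", "max").show()

spark.stop()
```

The Python wrapper adds no logic of its own beyond argument marshalling, which is why the inline doctests below are the main test surface.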
--- python/pyspark/sql/dataframe.py | 61 ++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cd208bb525a3..d1b2a9c9947e1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -927,7 +927,7 @@ def _sort_cols(self, cols, kwargs): @since("1.3.1") def describe(self, *cols): - """Computes statistics for numeric and string columns. + """Computes basic statistics for numeric and string columns. This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical or string columns. @@ -955,12 +955,71 @@ def describe(self, *cols): | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ + + Use summary for expanded statistics and control over which statistics to compute. """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) + @since("2.3.0") + def summary(self, *statistics): + """Computes specified statistics for numeric and string columns. Available statistics are: + - count + - mean + - stddev + - min + - max + - arbitrary approximate percentiles specified as a percentage (eg, 75%) + + If no statistics are given, this function computes count, mean, stddev, min, + approximate quartiles (percentiles at 25%, 50%, and 75%), and max. + + .. note:: This function is meant for exploratory data analysis, as we make no + guarantee about the backward compatibility of the schema of the resulting DataFrame. + + >>> df.summary().show() + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | 25%| 5.0| null| + | 50%| 5.0| null| + | 75%| 5.0| null| + | max| 5| Bob| + +-------+------------------+-----+ + + >>> df.summary("count", "min", "25%", "75%", "max").show() + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | min| 2|Alice| + | 25%|5.0| null| + | 75%|5.0| null| + | max| 5| Bob| + +-------+---+-----+ + + To do a summary for specific columns first select them: + + >>> df.select("age", "name").summary("count").show() + +-------+---+----+ + |summary|age|name| + +-------+---+----+ + | count| 2| 2| + +-------+---+----+ + + See also describe for basic statistics. + """ + if len(statistics) == 1 and isinstance(statistics[0], list): + statistics = statistics[0] + jdf = self._jdf.summary(self._jseq(statistics)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def head(self, n=None): From 72b738d8dcdb7893003c81bf1c73bbe262852d1a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Aug 2017 11:41:32 -0700 Subject: [PATCH 012/187] [SPARK-21790][TESTS] Fix Docker-based Integration Test errors. ## What changes were proposed in this pull request? [SPARK-17701](https://github.com/apache/spark/pull/18600/files#diff-b9f96d092fb3fea76bcf75e016799678L77) removed `metadata` function, this PR removed the Docker-based Integration module that has been relevant to `SparkPlan.metadata`. ## How was this patch tested? manual tests Author: Yuming Wang Closes #19000 from wangyum/SPARK-21709. 
--- external/docker-integration-tests/pom.xml | 7 +++++++ .../apache/spark/sql/jdbc/OracleIntegrationSuite.scala | 9 --------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 0fa87a697454b..485b562dce990 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -80,6 +80,13 @@ test-jar test + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index e14810a32edc6..80a129a9e0329 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -255,15 +255,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val df = dfRead.filter(dfRead.col("date_type").lt(dt)) .filter(dfRead.col("timestamp_type").lt(ts)) - val metadata = df.queryExecution.sparkPlan.metadata - // The "PushedFilters" part should be exist in Datafrome's - // physical plan and the existence of right literals in - // "PushedFilters" is used to prove that the predicates - // pushing down have been effective. - assert(metadata.get("PushedFilters").ne(None)) - assert(metadata("PushedFilters").contains(dt.toString)) - assert(metadata("PushedFilters").contains(ts.toString)) - val row = df.collect()(0) assert(row.getDate(0).equals(dateVal)) assert(row.getTimestamp(1).equals(timestampVal)) From 73e04ecc4f29a0fe51687ed1337c61840c976f89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pelvet?= Date: Sun, 20 Aug 2017 11:05:54 +0100 Subject: [PATCH 013/187] [MINOR] Correct validateAndTransformSchema in GaussianMixture and AFTSurvivalRegression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The line SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) did not modify the variable schema, hence only the last line had any effect. A temporary variable is used to correctly append the two columns predictionCol and probabilityCol. ## How was this patch tested? Manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Cédric Pelvet Closes #18980 from sharp-pixel/master. 
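The root cause described above is that `StructType` (and hence `SchemaUtils.appendColumn`) is immutable: each append returns a new schema and leaves its argument untouched, so a call whose result is discarded is a no-op. A minimal sketch of the broken and corrected patterns, using the public `StructType.add` (the column names and types here are illustrative, not the ML ones from the diff):

```scala
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("features", DoubleType)))

// Broken pattern: the returned schema is discarded, so `schema` still has a single field
// and only the last append in a sequence like this would ever take effect.
schema.add("prediction", IntegerType)

// Corrected pattern: capture each intermediate result (or chain the calls).
val withPrediction = schema.add("prediction", IntegerType)
val withProbability = withPrediction.add("probability", DoubleType)

assert(schema.fieldNames.toSeq == Seq("features"))
assert(withProbability.fieldNames.toSeq == Seq("features", "prediction", "probability"))
```

This is exactly the shape of the fix in the diff below: `validateAndTransformSchema` now threads the result of the first `appendColumn` into the second instead of discarding it.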
--- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 4 ++-- .../spark/ml/regression/AFTSurvivalRegression.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 5259ee419445f..f19ad7a5a6938 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 0891994530f88..16821f317760e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } - if (hasQuantilesCol) { + + val schemaWithQuantilesCol = if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } else schema + + SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType) } } From 41e0eb71a63140c9a44a7d2f32821f02abd62367 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 20 Aug 2017 19:48:04 +0900 Subject: [PATCH 014/187] [SPARK-21773][BUILD][DOCS] Installs mkdocs if missing in the path in SQL documentation build MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR proposes to install `mkdocs` by `pip install` if missing in the path. Mainly to fix Jenkins's documentation build failure in `spark-master-docs`. See https://amplab.cs.berkeley.edu/jenkins/job/spark-master-docs/3580/console. It also adds `mkdocs` as requirements in `docs/README.md`. ## How was this patch tested? I manually ran `jekyll build` under `docs` directory after manually removing `mkdocs` via `pip uninstall mkdocs`. Also, tested this in the same way but on CentOS Linux release 7.3.1611 (Core) where I built Spark few times but never built documentation before and `mkdocs` is not installed. ``` ... Moving back into docs dir. Moving to SQL directory and building docs. Missing mkdocs in your path, trying to install mkdocs for SQL documentation generation. 
Collecting mkdocs Downloading mkdocs-0.16.3-py2.py3-none-any.whl (1.2MB) 100% |████████████████████████████████| 1.2MB 574kB/s Requirement already satisfied: PyYAML>=3.10 in /usr/lib64/python2.7/site-packages (from mkdocs) Collecting livereload>=2.5.1 (from mkdocs) Downloading livereload-2.5.1-py2-none-any.whl Collecting tornado>=4.1 (from mkdocs) Downloading tornado-4.5.1.tar.gz (483kB) 100% |████████████████████████████████| 491kB 1.4MB/s Collecting Markdown>=2.3.1 (from mkdocs) Downloading Markdown-2.6.9.tar.gz (271kB) 100% |████████████████████████████████| 276kB 2.4MB/s Collecting click>=3.3 (from mkdocs) Downloading click-6.7-py2.py3-none-any.whl (71kB) 100% |████████████████████████████████| 71kB 2.8MB/s Requirement already satisfied: Jinja2>=2.7.1 in /usr/lib/python2.7/site-packages (from mkdocs) Requirement already satisfied: six in /usr/lib/python2.7/site-packages (from livereload>=2.5.1->mkdocs) Requirement already satisfied: backports.ssl_match_hostname in /usr/lib/python2.7/site-packages (from tornado>=4.1->mkdocs) Collecting singledispatch (from tornado>=4.1->mkdocs) Downloading singledispatch-3.4.0.3-py2.py3-none-any.whl Collecting certifi (from tornado>=4.1->mkdocs) Downloading certifi-2017.7.27.1-py2.py3-none-any.whl (349kB) 100% |████████████████████████████████| 358kB 2.1MB/s Collecting backports_abc>=0.4 (from tornado>=4.1->mkdocs) Downloading backports_abc-0.5-py2.py3-none-any.whl Requirement already satisfied: MarkupSafe>=0.23 in /usr/lib/python2.7/site-packages (from Jinja2>=2.7.1->mkdocs) Building wheels for collected packages: tornado, Markdown Running setup.py bdist_wheel for tornado ... done Stored in directory: /root/.cache/pip/wheels/84/83/cd/6a04602633457269d161344755e6766d24307189b7a67ff4b7 Running setup.py bdist_wheel for Markdown ... done Stored in directory: /root/.cache/pip/wheels/bf/46/10/c93e17ae86ae3b3a919c7b39dad3b5ccf09aeb066419e5c1e5 Successfully built tornado Markdown Installing collected packages: singledispatch, certifi, backports-abc, tornado, livereload, Markdown, click, mkdocs Successfully installed Markdown-2.6.9 backports-abc-0.5 certifi-2017.7.27.1 click-6.7 livereload-2.5.1 mkdocs-0.16.3 singledispatch-3.4.0.3 tornado-4.5.1 Generating markdown files for SQL documentation. Generating HTML files for SQL documentation. INFO - Cleaning site directory INFO - Building documentation to directory: .../spark/sql/site Moving back into docs dir. Making directory api/sql cp -r ../sql/site/. api/sql Source: .../spark/docs Destination: .../spark/docs/_site Generating... done. Auto-regeneration: disabled. Use --watch to enable. ``` Author: hyukjinkwon Closes #18984 from HyukjinKwon/sql-doc-mkdocs. --- docs/README.md | 2 +- sql/create-docs.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/README.md b/docs/README.md index 0090dd071e15f..866364f1566a9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -19,7 +19,7 @@ installed. 
Also install the following libraries: $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs - $ sudo pip install sphinx pypandoc + $ sudo pip install sphinx pypandoc mkdocs $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' ``` (Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) diff --git a/sql/create-docs.sh b/sql/create-docs.sh index 275e4c391a388..1d2d602c979be 100755 --- a/sql/create-docs.sh +++ b/sql/create-docs.sh @@ -33,8 +33,8 @@ if ! hash python 2>/dev/null; then fi if ! hash mkdocs 2>/dev/null; then - echo "Missing mkdocs in your path, skipping SQL documentation generation." - exit 0 + echo "Missing mkdocs in your path, trying to install mkdocs for SQL documentation generation." + pip install mkdocs fi # Now create the markdown file From 28a6cca7df900d13613b318c07acb97a5722d2b8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Aug 2017 00:45:23 +0800 Subject: [PATCH 015/187] [SPARK-21721][SQL][FOLLOWUP] Clear FileSystem deleteOnExit cache when paths are successfully removed ## What changes were proposed in this pull request? Fix a typo in test. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #19005 from viirya/SPARK-21721-followup. --- .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ef3d9b27aad79..d0e0d20df30af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2023,7 +2023,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { - withTable("test21721") { + val table = "test21721" + withTable(table) { val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") deleteOnExitField.setAccessible(true) @@ -2031,10 +2032,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() - sql("CREATE TABLE test21721 (key INT, value STRING)") + sql(s"CREATE TABLE $table (key INT, value STRING)") val pathSizeToDeleteOnExit = setOfPath.size() - (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto("test1")) + (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto(table)) assert(setOfPath.size() == pathSizeToDeleteOnExit) } From 77d046ec47a9bfa6323aa014869844c28e18e049 Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Mon, 21 Aug 2017 08:21:25 +0100 Subject: [PATCH 016/187] [SPARK-21782][CORE] Repartition creates skews when numPartitions is a power of 2 ## Problem When an RDD (particularly with a low item-per-partition ratio) is repartitioned to numPartitions = power of 2, the resulting partitions are very uneven-sized, due to using fixed seed to initialize PRNG, and using the PRNG only once. See details in https://issues.apache.org/jira/browse/SPARK-21782 ## What changes were proposed in this pull request? 
Instead of directly using `0, 1, 2,...` seeds to initialize `Random`, hash them with `scala.util.hashing.byteswap32()`. ## How was this patch tested? `build/mvn -Dtest=none -DwildcardSuites=org.apache.spark.rdd.RDDSuite test` Author: Sergey Serebryakov Closes #18990 from megaserg/repartition-skew. --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5435f59ea0d28..8798dfc925362 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.io.Codec import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} +import scala.util.hashing import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.io.{BytesWritable, NullWritable, Text} @@ -448,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(index)).nextInt(numPartitions) + var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 386c0060f9c41..e994d724c462f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -347,16 +347,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { val partitions = repartitioned.glom().collect() // assert all elements are present assert(repartitioned.collect().sortWith(_ > _).toSeq === input.toSeq.sortWith(_ > _).toSeq) - // assert no bucket is overloaded + // assert no bucket is overloaded or empty for (partition <- partitions) { val avg = input.size / finalPartitions val maxPossible = avg + initialPartitions - assert(partition.length <= maxPossible) + assert(partition.length <= maxPossible) + assert(!partition.isEmpty) } } testSplitPartitions(Array.fill(100)(1), 10, 20) testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100) + testSplitPartitions(Array.fill(1000)(1), 250, 128) } test("coalesced RDDs") { From b3a07526fe774fd64fe3a2b9a2381eff9a3c49a3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Aug 2017 14:20:40 +0200 Subject: [PATCH 017/187] [SPARK-21718][SQL] Heavy log of type: "Skipping partition based on stats ..." ## What changes were proposed in this pull request? Reduce 'Skipping partitions' message to debug ## How was this patch tested? Existing tests Author: Sean Owen Closes #19010 from srowen/SPARK-21718. 
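
A standalone way to see the effect of the SPARK-21782 seeding change above, before its diff and the logging diff below: hash each input partition index with `byteswap32` before seeding the PRNG so the starting output partition is no longer tied to the index. This is only an illustrative sketch (the variable names and the small partition counts are made up for the example; `scala.util.Random` and `scala.util.hashing.byteswap32` are the real APIs used by the patch), not the actual `RDD.repartition` code.

```
import scala.util.Random
import scala.util.hashing.byteswap32

val numPartitions = 64            // a power of 2, the case reported to skew
val inputPartitions = 0 until 16  // input partition indices

// Old seeding: the raw index is the seed, so consecutive indices give
// correlated starting positions, which the JIRA reports as skewed output
// sizes when numPartitions is a power of 2.
val oldStarts = inputPartitions.map(i => new Random(i).nextInt(numPartitions))

// New seeding (SPARK-21782): scramble the index with byteswap32 first,
// which decorrelates the seeds and spreads the starting positions out.
val newStarts = inputPartitions.map(i => new Random(byteswap32(i)).nextInt(numPartitions))

println(s"old starts: ${oldStarts.mkString(", ")}")
println(s"new starts: ${newStarts.mkString(", ")}")
```

Pasting the two maps into a REPL side by side is a quick way to compare the two seeding schemes the patch describes.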
--- .../sql/execution/columnar/InMemoryTableScanExec.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 1d601374de135..c7ddec55682e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -166,12 +166,13 @@ case class InMemoryTableScanExec( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter.eval(cachedBatch.stats)) { - def statsString: String = schemaIndex.map { - case (a, i) => + logDebug { + val statsString = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" - }.mkString(", ") - logInfo(s"Skipping partition based on stats $statsString") + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } false } else { true From 988b84d7ed43bea2616527ff050dffcf20548ab2 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 21 Aug 2017 14:35:38 +0200 Subject: [PATCH 018/187] [SPARK-21468][PYSPARK][ML] Python API for FeatureHasher Add Python API for `FeatureHasher` transformer. ## How was this patch tested? New doc test. Author: Nick Pentreath Closes #18970 from MLnick/SPARK-21468-pyspark-hasher. --- .../spark/ml/feature/FeatureHasher.scala | 16 ++-- python/pyspark/ml/feature.py | 77 +++++++++++++++++++ 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d22bf164c313c..4b91fa933ed9f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -64,17 +64,17 @@ import org.apache.spark.util.collection.OpenHashMap * ).toDF("real", "bool", "stringNum", "string") * * val hasher = new FeatureHasher() - * .setInputCols("real", "bool", "stringNum", "num") + * .setInputCols("real", "bool", "stringNum", "string") * .setOutputCol("features") * - * hasher.transform(df).show() + * hasher.transform(df).show(false) * - * +----+-----+---------+------+--------------------+ - * |real| bool|stringNum|string| features| - * +----+-----+---------+------+--------------------+ - * | 2.0| true| 1| foo|(262144,[51871,63...| - * | 3.0|false| 2| bar|(262144,[6031,806...| - * +----+-----+---------+------+--------------------+ + * +----+-----+---------+------+------------------------------------------------------+ + * |real|bool |stringNum|string|features | + * +----+-----+---------+------+------------------------------------------------------+ + * |2.0 |true |1 |foo |(262144,[51871,63643,174475,253195],[1.0,1.0,2.0,1.0])| + * |3.0 |false|2 |bar |(262144,[6031,80619,140467,174475],[1.0,1.0,1.0,3.0]) | + * +----+-----+---------+------+------------------------------------------------------+ * }}} */ @Experimental diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 54b4026f78bec..050537b811f61 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -34,6 +34,7 @@ 'CountVectorizer', 'CountVectorizerModel', 'DCT', 'ElementwiseProduct', + 'FeatureHasher', 'HashingTF', 'IDF', 'IDFModel', 'Imputer', 'ImputerModel', @@ -696,6 +697,82 @@ def 
getScalingVec(self): return self.getOrDefault(self.scalingVec) +@inherit_doc +class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Feature hashing projects a set of categorical or numerical features into a feature vector of + specified dimension (typically substantially smaller than that of the original feature + space). This is done using the hashing trick (https://en.wikipedia.org/wiki/Feature_hashing) + to map features to indices in the feature vector. + + The FeatureHasher transformer operates on multiple columns. Each column may contain either + numeric or categorical features. Behavior and handling of column data types is as follows: + + * Numeric columns: + For numeric features, the hash value of the column name is used to map the + feature value to its index in the feature vector. Numeric features are never + treated as categorical, even when they are integers. You must explicitly + convert numeric columns containing categorical features to strings first. + + * String columns: + For categorical features, the hash value of the string "column_name=value" + is used to map to the vector index, with an indicator value of `1.0`. + Thus, categorical features are "one-hot" encoded + (similarly to using :py:class:`OneHotEncoder` with `dropLast=false`). + + * Boolean columns: + Boolean values are treated in the same way as string columns. That is, + boolean features are represented as "column_name=true" or "column_name=false", + with an indicator value of `1.0`. + + Null (missing) values are ignored (implicitly zero in the resulting feature vector). + + Since a simple modulo is used to transform the hash function to a vector index, + it is advisable to use a power of two as the `numFeatures` parameter; + otherwise the features will not be mapped evenly to the vector indices. + + >>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")] + >>> cols = ["real", "bool", "stringNum", "string"] + >>> df = spark.createDataFrame(data, cols) + >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") + >>> hasher.transform(df).head().features + SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + >>> hasherPath = temp_path + "/hasher" + >>> hasher.save(hasherPath) + >>> loadedHasher = FeatureHasher.load(hasherPath) + >>> loadedHasher.getNumFeatures() == hasher.getNumFeatures() + True + >>> loadedHasher.transform(df).head().features == hasher.transform(df).head().features + True + + .. versionadded:: 2.3.0 + """ + + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + """ + __init__(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + """ + super(FeatureHasher, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid) + self._setDefault(numFeatures=1 << 18) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None): + """ + setParams(self, numFeatures=1 << 18, inputCols=None, outputCol=None) + Sets params for this FeatureHasher. 
+ """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable): From ba843292e37368e1f5e4ae5c99ba1f5f90ca6025 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Aug 2017 10:16:56 -0700 Subject: [PATCH 019/187] [SPARK-21790][TESTS][FOLLOW-UP] Add filter pushdown verification back. ## What changes were proposed in this pull request? The previous PR(https://github.com/apache/spark/pull/19000) removed filter pushdown verification, This PR add them back. ## How was this patch tested? manual tests Author: Yuming Wang Closes #19002 from wangyum/SPARK-21790-follow-up. --- .../spark/sql/jdbc/OracleIntegrationSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 80a129a9e0329..1b2c1b9e800ac 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -22,6 +22,7 @@ import java.util.Properties import java.math.BigDecimal import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -255,6 +256,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val df = dfRead.filter(dfRead.col("date_type").lt(dt)) .filter(dfRead.col("timestamp_type").lt(ts)) + val parentPlan = df.queryExecution.executedPlan + assert(parentPlan.isInstanceOf[WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[WholeStageCodegenExec] + val metadata = node.child.asInstanceOf[RowDataSourceScanExec].metadata + // The "PushedFilters" part should exist in Dataframe's + // physical plan and the existence of right literals in + // "PushedFilters" is used to prove that the predicates + // pushing down have been effective. + assert(metadata.get("PushedFilters").isDefined) + assert(metadata("PushedFilters").contains(dt.toString)) + assert(metadata("PushedFilters").contains(ts.toString)) + val row = df.collect()(0) assert(row.getDate(0).equals(dateVal)) assert(row.getTimestamp(1).equals(timestampVal)) From 84b5b16ea6c9816c70f7471a50eb5e4acb7fb1a1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Aug 2017 15:09:02 -0700 Subject: [PATCH 020/187] [SPARK-21617][SQL] Store correct table metadata when altering schema in Hive metastore. For Hive tables, the current "replace the schema" code is the correct path, except that an exception in that path should result in an error, and not in retrying in a different way. For data source tables, Spark may generate a non-compatible Hive table; but for that to work with Hive 2.1, the detection of data source tables needs to be fixed in the Hive client, to also consider the raw tables used by code such as `alterTableSchema`. Tested with existing and added unit tests (plus internal tests with a 2.1 metastore). Author: Marcelo Vanzin Closes #18849 from vanzin/SPARK-21617. 
--- .../sql/execution/command/DDLSuite.scala | 15 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 55 +++++--- .../sql/hive/client/HiveClientImpl.scala | 3 +- .../hive/execution/Hive_2_1_DDLSuite.scala | 126 ++++++++++++++++++ 4 files changed, 171 insertions(+), 28 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9332f773430e7..ad6fc20df1f02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2357,18 +2357,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("Found duplicate column(s)")) } else { - if (isUsingHiveMetastore) { - // hive catalog will still complains that c1 is duplicate column name because hive - // identifiers are case insensitive. - val e = intercept[AnalysisException] { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - }.getMessage - assert(e.contains("HiveException")) - } else { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - assert(spark.table("t1").schema - .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) - } + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema == + new StructType().add("c1", IntegerType).add("C1", StringType)) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 547447b31f0a1..bdbb8bccbc5cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { client.getTable(db, table) } @@ -386,6 +386,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * can be used as table properties later. */ private def tableMetaToTableProps(table: CatalogTable): mutable.Map[String, String] = { + tableMetaToTableProps(table, table.schema) + } + + private def tableMetaToTableProps( + table: CatalogTable, + schema: StructType): mutable.Map[String, String] = { val partitionColumns = table.partitionColumnNames val bucketSpec = table.bucketSpec @@ -397,7 +403,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // property. In this case, we split the JSON string and store each part as a separate table // property. val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) - val schemaJsonString = table.schema.json + val schemaJsonString = schema.json // Split the JSON string. 
val parts = schemaJsonString.grouped(threshold).toSeq properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString) @@ -615,20 +621,29 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) - val withNewSchema = rawTable.copy(schema = schema) - verifyColumnNames(withNewSchema) // Add table metadata such as table schema, partition columns, etc. to table properties. - val updatedTable = withNewSchema.copy( - properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) - try { - client.alterTable(updatedTable) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + - "compatible way. Updating Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema)) + val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema) + val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema) + verifyColumnNames(withNewSchema) + + if (isDatasourceTable(rawTable)) { + // For data source tables, first try to write it with the schema set; if that does not work, + // try again with updated properties and the partition schema. This is a simplified version of + // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive + // (for example, the schema does not match the data source schema, or does not match the + // storage descriptor). + try { + client.alterTable(withNewSchema) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + "compatible way. Updating Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema)) + } + } else { + client.alterTable(withNewSchema) } } @@ -1351,4 +1366,14 @@ object HiveExternalCatalog { getColumnNamesByType(metadata.properties, "sort", "sorting columns")) } } + + /** + * Detects a data source table. This checks both the table provider and the table properties, + * unlike DDLUtils which just checks the former. 
+ */ + private[spark] def isDatasourceTable(table: CatalogTable): Boolean = { + val provider = table.provider.orElse(table.properties.get(DATASOURCE_PROVIDER)) + provider.isDefined && provider != Some(DDLUtils.HIVE_PROVIDER) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 995280e0e9416..7c0b9bf19bf30 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -883,7 +884,7 @@ private[hive] object HiveClientImpl { } // after SPARK-19279, it is not allowed to create a hive table with an empty schema, // so here we should not add a default col schema - if (schema.isEmpty && DDLUtils.isDatasourceTable(table)) { + if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) { // This is a hack to preserve existing behavior. Before Spark 2.0, we do not // set a default serde here (this was done in Hive), and so if the user provides // an empty schema Hive would automatically populate the schema with a single diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala new file mode 100644 index 0000000000000..5c248b9acd04f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import scala.language.existentials + +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.util.Utils + +/** + * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently + * from the built-in ones. + */ +@ExtendedHiveTest +class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach + with BeforeAndAfterAll { + + // Create a custom HiveExternalCatalog instance with the desired configuration. We cannot + // use SparkSession here since there's already an active on managed by the TestHive object. + private var catalog = { + val warehouse = Utils.createTempDir() + val metastore = Utils.createTempDir() + metastore.delete() + val sparkConf = new SparkConf() + .set(SparkLauncher.SPARK_MASTER, "local") + .set(WAREHOUSE_PATH.key, warehouse.toURI().toString()) + .set(CATALOG_IMPLEMENTATION.key, "hive") + .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.1") + .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven") + + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.warehouse.dir", warehouse.toURI().toString()) + hadoopConf.set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=${metastore.getAbsolutePath()};create=true") + // These options are needed since the defaults in Hive 2.1 cause exceptions with an + // empty metastore db. 
+ hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") + + new HiveExternalCatalog(sparkConf, hadoopConf) + } + + override def afterEach: Unit = { + catalog.listTables("default").foreach { t => + catalog.dropTable("default", t, true, false) + } + spark.sessionState.catalog.reset() + } + + override def afterAll(): Unit = { + catalog = null + } + + test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING json", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))), + hiveCompatible = false) + } + + test("SPARK-21617: ALTER TABLE for Hive-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE for Hive tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) STORED AS parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE with incompatible schema on Hive-compatible table") { + val exception = intercept[AnalysisException] { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 string) USING parquet", + StructType(Array(StructField("c2", IntegerType)))) + } + assert(exception.getMessage().contains("types incompatible with the existing columns")) + } + + private def testAlterTable( + tableName: String, + createTableStmt: String, + updatedSchema: StructType, + hiveCompatible: Boolean = true): Unit = { + spark.sql(createTableStmt) + val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName) + catalog.createTable(oldTable, true) + catalog.alterTableSchema("default", tableName, updatedSchema) + + val updatedTable = catalog.getTable("default", tableName) + assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames) + } + +} From c108a5d30e821fef23709681fca7da22bc507129 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Aug 2017 08:43:18 +0800 Subject: [PATCH 021/187] [SPARK-19762][ML][FOLLOWUP] Add necessary comments to L2Regularization. ## What changes were proposed in this pull request? MLlib ```LinearRegression/LogisticRegression/LinearSVC``` always standardize the data during training to improve the rate of convergence regardless of _standardization_ is true or false. If _standardization_ is false, we perform reverse standardization by penalizing each component differently to get effectively the same objective function when the training dataset is not standardized. We should keep these comments in the code to let developers understand how we handle it correctly. ## How was this patch tested? Existing tests, only adding some comments in code. Author: Yanbo Liang Closes #18992 from yanboliang/SPARK-19762. 
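
To make the "reverse standardization" described above concrete, here is a minimal sketch mirroring the loop in the diff below. The function name `l2` and the `(loss, gradient)` return shape are assumptions for illustration; only the per-component arithmetic is taken from the patch.

```
// coefficients: current model coefficients (column major)
// shouldApply:  usually excludes the intercept from regularization
// featuresStd:  Some(j => std_j) when standardization = false (reverse
//               standardization); None when handled by the caller
def l2(
    coefficients: Array[Double],
    regParam: Double,
    shouldApply: Int => Boolean,
    featuresStd: Option[Int => Double]): (Double, Array[Double]) = {
  val gradient = new Array[Double](coefficients.length)
  var sum = 0.0
  var j = 0
  while (j < coefficients.length) {
    if (shouldApply(j)) {
      val coef = coefficients(j)
      featuresStd match {
        case Some(getStd) =>
          val std = getStd(j)
          if (std != 0.0) {
            // Penalize each component by 1 / std^2 so the objective matches
            // the one on the original, un-standardized data.
            val temp = coef / (std * std)
            sum += coef * temp
            gradient(j) = regParam * temp
          }
        case None =>
          // Plain L2 when the data is treated as already standardized.
          sum += coef * coef
          gradient(j) = coef * regParam
      }
    }
    j += 1
  }
  (0.5 * sum * regParam, gradient)
}
```

In other words, the training always happens in the standardized feature space for convergence, and the per-component 1 / std^2 factor is what keeps the objective equivalent when the user asked for no standardization.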
--- .../optim/loss/DifferentiableRegularization.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala index 7ac7c225e5acb..929374eda13a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala @@ -39,9 +39,13 @@ private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] { * * @param regParam The magnitude of the regularization. * @param shouldApply A function (Int => Boolean) indicating whether a given index should have - * regularization applied to it. + * regularization applied to it. Usually we don't apply regularization to + * the intercept. * @param applyFeaturesStd Option for a function which maps coefficient index (column major) to the - * feature standard deviation. If `None`, no standardization is applied. + * feature standard deviation. Since we always standardize the data during + * training, if `standardization` is false, we have to reverse + * standardization by penalizing each component differently by this param. + * If `standardization` is true, this should be `None`. */ private[ml] class L2Regularization( override val regParam: Double, @@ -57,6 +61,11 @@ private[ml] class L2Regularization( val coef = coefficients(j) applyFeaturesStd match { case Some(getStd) => + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. val std = getStd(j) if (std != 0.0) { val temp = coef / (std * std) @@ -66,6 +75,7 @@ private[ml] class L2Regularization( 0.0 } case None => + // If `standardization` is true, compute L2 regularization normally. sum += coef * coef gradient(j) = coef * regParam } From 751f513367ae776c6d6815e1ce138078924872eb Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 22 Aug 2017 11:17:53 +0900 Subject: [PATCH 022/187] [SPARK-21070][PYSPARK] Attempt to update cloudpickle again ## What changes were proposed in this pull request? Based on https://github.com/apache/spark/pull/18282 by rgbkrk this PR attempts to update to the current released cloudpickle and minimize the difference between Spark cloudpickle and "stock" cloud pickle with the goal of eventually using the stock cloud pickle. 
Some notable changes: * Import submodules accessed by pickled functions (cloudpipe/cloudpickle#80) * Support recursive functions inside closures (cloudpipe/cloudpickle#89, cloudpipe/cloudpickle#90) * Fix ResourceWarnings and DeprecationWarnings (cloudpipe/cloudpickle#88) * Assume modules with __file__ attribute are not dynamic (cloudpipe/cloudpickle#85) * Make cloudpickle Python 3.6 compatible (cloudpipe/cloudpickle#72) * Allow pickling of builtin methods (cloudpipe/cloudpickle#57) * Add ability to pickle dynamically created modules (cloudpipe/cloudpickle#52) * Support method descriptor (cloudpipe/cloudpickle#46) * No more pickling of closed files, was broken on Python 3 (cloudpipe/cloudpickle#32) * ** Remove non-standard __transient__check (cloudpipe/cloudpickle#110)** -- while we don't use this internally, and have no tests or documentation for its use, downstream code may use __transient__, although it has never been part of the API, if we merge this we should include a note about this in the release notes. * Support for pickling loggers (yay!) (cloudpipe/cloudpickle#96) * BUG: Fix crash when pickling dynamic class cycles. (cloudpipe/cloudpickle#102) ## How was this patch tested? Existing PySpark unit tests + the unit tests from the cloudpickle project on their own. Author: Holden Karau Author: Kyle Kelley Closes #18734 from holdenk/holden-rgbkrk-cloudpickle-upgrades. --- python/pyspark/cloudpickle.py | 599 +++++++++++++++++++++++++--------- 1 file changed, 446 insertions(+), 153 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 389bee7eee6e9..40e91a2d0655d 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -9,10 +9,10 @@ It does not include an unpickler, as standard python unpickling suffices. This module was extracted from the `cloud` package, developed by `PiCloud, Inc. -`_. +`_. Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. +Copyright (c) 2009 `PiCloud, Inc. `_. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -42,18 +42,19 @@ """ from __future__ import print_function -import operator -import opcode -import os +import dis +from functools import partial +import imp import io +import itertools +import logging +import opcode +import operator import pickle import struct import sys -import types -from functools import partial -import itertools -import dis import traceback +import types import weakref from pyspark.util import _exception_message @@ -71,6 +72,92 @@ from io import BytesIO as StringIO PY3 = True + +def _make_cell_set_template_code(): + """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF + + Notes + ----- + In Python 3, we could use an easier function: + + .. code-block:: python + + def f(): + cell = None + + def _stub(value): + nonlocal cell + cell = value + + return _stub + + _cell_set_template_code = f() + + This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is + invalid syntax on Python 2. If we use this function we also don't need + to do the weird freevars/cellvars swap below + """ + def inner(value): + lambda: cell # make ``cell`` a closure so that we get a STORE_DEREF + cell = value + + co = inner.__code__ + + # NOTE: we are marking the cell variable as a free variable intentionally + # so that we simulate an inner function instead of the outer function. This + # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. 
+ if not PY3: + return types.CodeType( + co.co_argcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + else: + return types.CodeType( + co.co_argcount, + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + + +_cell_set_template_code = _make_cell_set_template_code() + + +def cell_set(cell, value): + """Set the value of a closure cell. + """ + return types.FunctionType( + _cell_set_template_code, + {}, + '_cell_set_inner', + (), + (cell,), + )(value) + + #relevant opcodes STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] @@ -161,6 +248,7 @@ def dump(self, obj): print_exec(sys.stderr) raise pickle.PicklingError(msg) + def save_memoryview(self, obj): """Fallback to save_string""" Pickler.save_string(self, str(obj)) @@ -186,8 +274,22 @@ def save_module(self, obj): """ Save a module as an import """ + mod_name = obj.__name__ + # If module is successfully found then it is not a dynamically created module + if hasattr(obj, '__file__'): + is_dynamic = False + else: + try: + _find_module(mod_name) + is_dynamic = False + except ImportError: + is_dynamic = True + self.modules.add(obj) - self.save_reduce(subimport, (obj.__name__,), obj=obj) + if is_dynamic: + self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)), obj=obj) + else: + self.save_reduce(subimport, (obj.__name__,), obj=obj) dispatch[types.ModuleType] = save_module def save_codeobject(self, obj): @@ -241,11 +343,32 @@ def save_function(self, obj, name=None): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) + # a builtin_function_or_method which comes in as an attribute of some + # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # up with modname "__main__" and so end up here. But these functions + # have no __code__ attribute in CPython, so the handling for + # user-defined functions below will fail. + # So we pickle them here using save_reduce; have to do it differently + # for different python versions. + if not hasattr(obj, '__code__'): + if PY3: + if sys.version_info < (3, 4): + raise pickle.PicklingError("Can't pickle %r" % obj) + else: + rv = obj.__reduce_ex__(self.proto) + else: + if hasattr(obj, '__self__'): + rv = (getattr, (obj.__self__, name)) + else: + raise pickle.PicklingError("Can't pickle %r" % obj) + return Pickler.save_reduce(self, obj=obj, *rv) + # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. 
- if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: - #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + if (islambda(obj) + or getattr(obj.__code__, 'co_filename', None) == '' + or themodule is None): self.save_function_tuple(obj) return else: @@ -267,6 +390,97 @@ def save_function(self, obj, name=None): self.memoize(obj) dispatch[types.FunctionType] = save_function + def _save_subimports(self, code, top_level_dependencies): + """ + Ensure de-pickler imports any package child-modules that + are needed by the function + """ + # check if any known dependency is an imported package + for x in top_level_dependencies: + if isinstance(x, types.ModuleType) and hasattr(x, '__package__') and x.__package__: + # check if the package has any currently loaded sub-imports + prefix = x.__name__ + '.' + for name, module in sys.modules.items(): + # Older versions of pytest will add a "None" module to sys.modules. + if name is not None and name.startswith(prefix): + # check whether the function can address the sub-module + tokens = set(name[len(prefix):].split('.')) + if not tokens - set(code.co_names): + # ensure unpickler executes this import + self.save(module) + # then discards the reference to it + self.write(pickle.POP) + + def save_dynamic_class(self, obj): + """ + Save a class that can't be stored as module global. + + This method is used to serialize classes that are defined inside + functions, or that otherwise can't be serialized as attribute lookups + from global modules. + """ + clsdict = dict(obj.__dict__) # copy dict proxy to a dict + if not isinstance(clsdict.get('__dict__', None), property): + # don't extract dict that are properties + clsdict.pop('__dict__', None) + clsdict.pop('__weakref__', None) + + # hack as __new__ is stored differently in the __dict__ + new_override = clsdict.get('__new__', None) + if new_override: + clsdict['__new__'] = obj.__new__ + + # namedtuple is a special case for Spark where we use the _load_namedtuple function + if getattr(obj, '_is_namedtuple_', False): + self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) + return + + save = self.save + write = self.write + + # We write pickle instructions explicitly here to handle the + # possibility that the type object participates in a cycle with its own + # __dict__. We first write an empty "skeleton" version of the class and + # memoize it before writing the class' __dict__ itself. We then write + # instructions to "rehydrate" the skeleton class by restoring the + # attributes from the __dict__. + # + # A type can appear in a cycle with its __dict__ if an instance of the + # type appears in the type's __dict__ (which happens for the stdlib + # Enum class), or if the type defines methods that close over the name + # of the type, (which is common for Python 2-style super() calls). + + # Push the rehydration function. + save(_rehydrate_skeleton_class) + + # Mark the start of the args for the rehydration function. + write(pickle.MARK) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + doc_dict = {'__doc__': clsdict.pop('__doc__', None)} + + # Create and memoize an empty class with obj's name and bases. + save(type(obj)) + save(( + obj.__name__, + obj.__bases__, + doc_dict, + )) + write(pickle.REDUCE) + self.memoize(obj) + + # Now save the rest of obj's __dict__. 
Any references to obj + # encountered while saving will point to the skeleton class. + save(clsdict) + + # Write a tuple of (skeleton_class, clsdict). + write(pickle.TUPLE) + + # Call _rehydrate_skeleton_class(skeleton_class, clsdict) + write(pickle.REDUCE) + def save_function_tuple(self, func): """ Pickles an actual func object. @@ -279,17 +493,31 @@ def save_function_tuple(self, func): safe, since this won't contain a ref to the func), and memoize it as soon as it's created. The other stuff can then be filled in later. """ + if is_tornado_coroutine(func): + self.save_reduce(_rebuild_tornado_coroutine, (func.__wrapped__,), + obj=func) + return + save = self.save write = self.write - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + code, f_globals, defaults, closure_values, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater write(pickle.MARK) # beginning of tuple that _fill_function expects + self._save_subimports( + code, + itertools.chain(f_globals.values(), closure_values or ()), + ) + # create a skeleton function object and memoize it save(_make_skel_func) - save((code, closure, base_globals)) + save(( + code, + len(closure_values) if closure_values is not None else -1, + base_globals, + )) write(pickle.REDUCE) self.memoize(func) @@ -298,6 +526,7 @@ def save_function_tuple(self, func): save(defaults) save(dct) save(func.__module__) + save(closure_values) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @@ -335,7 +564,7 @@ def extract_code_globals(cls, co): def extract_func_data(self, func): """ Turn the function into a tuple of data necessary to recreate it: - code, globals, defaults, closure, dict + code, globals, defaults, closure_values, dict """ code = func.__code__ @@ -352,7 +581,11 @@ def extract_func_data(self, func): defaults = func.__defaults__ # process closure - closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] + closure = ( + list(map(_get_cell_contents, func.__closure__)) + if func.__closure__ is not None + else None + ) # save the dict dct = func.__dict__ @@ -363,12 +596,18 @@ def extract_func_data(self, func): return (code, f_globals, defaults, closure, dct, base_globals) def save_builtin_function(self, obj): - if obj.__module__ is "__builtin__": + if obj.__module__ == "__builtin__": return self.save_global(obj) return self.save_function(obj) dispatch[types.BuiltinFunctionType] = save_builtin_function def save_global(self, obj, name=None, pack=struct.pack): + """ + Save a "global". + + The name of this method is somewhat misleading: all types get + dispatched here. 
+ """ if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": if obj in _BUILTIN_TYPE_NAMES: return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) @@ -397,42 +636,7 @@ def save_global(self, obj, name=None, pack=struct.pack): typ = type(obj) if typ is not obj and isinstance(obj, (type, types.ClassType)): - d = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): - # don't extract dict that are properties - d.pop('__dict__', None) - d.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = d.get('__new__', None) - if new_override: - d['__new__'] = obj.__new__ - - # workaround for namedtuple (hijacked by PySpark) - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return - - self.save(_load_class) - self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) - d.pop('__doc__', None) - # handle property and staticmethod - dd = {} - for k, v in d.items(): - if isinstance(v, property): - k = ('property', k) - v = (v.fget, v.fset, v.fdel, v.__doc__) - elif isinstance(v, staticmethod) and hasattr(v, '__func__'): - k = ('staticmethod', k) - v = v.__func__ - elif isinstance(v, classmethod) and hasattr(v, '__func__'): - k = ('classmethod', k) - v = v.__func__ - dd[k] = v - self.save(dd) - self.write(pickle.TUPLE2) - self.write(pickle.REDUCE) - + self.save_dynamic_class(obj) else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -441,18 +645,26 @@ def save_global(self, obj, name=None, pack=struct.pack): def save_instancemethod(self, obj): # Memoization rarely is ever useful due to python bounding - if PY3: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + if obj.__self__ is None: + self.save_reduce(getattr, (obj.im_class, obj.__name__)) else: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + if PY3: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + else: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): - """Inner logic to save instance. Based off pickle.save_inst - Supports __transient__""" + """Inner logic to save instance. 
Based off pickle.save_inst""" cls = obj.__class__ + # Try the dispatch table (pickle module doesn't do it) + f = self.dispatch.get(cls) + if f: + f(self, obj) # Call unbound method with explicit self + return + memo = self.memo write = self.write save = self.save @@ -482,13 +694,6 @@ def save_inst(self, obj): getstate = obj.__getstate__ except AttributeError: stuff = obj.__dict__ - #remove items if transient - if hasattr(obj, '__transient__'): - transient = obj.__transient__ - stuff = stuff.copy() - for k in list(stuff.keys()): - if k in transient: - del stuff[k] else: stuff = getstate() pickle._keep_alive(stuff, memo) @@ -503,6 +708,17 @@ def save_property(self, obj): self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj) dispatch[property] = save_property + def save_classmethod(self, obj): + try: + orig_func = obj.__func__ + except AttributeError: # Python 2.6 + orig_func = obj.__get__(None, object) + if isinstance(obj, classmethod): + orig_func = orig_func.__func__ # Unbind + self.save_reduce(type(obj), (orig_func,), obj=obj) + dispatch[classmethod] = save_classmethod + dispatch[staticmethod] = save_classmethod + def save_itemgetter(self, obj): """itemgetter serializer (needed for namedtuple support)""" class Dummy: @@ -540,8 +756,6 @@ def __getattribute__(self, item): def save_reduce(self, func, args, state=None, listitems=None, dictitems=None, obj=None): - """Modified to support __transient__ on new objects - Change only affects protocol level 2 (which is always used by PiCloud""" # Assert that args is a tuple or None if not isinstance(args, tuple): raise pickle.PicklingError("args from reduce() should be a tuple") @@ -555,7 +769,6 @@ def save_reduce(self, func, args, state=None, # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - #Added fix to allow transient cls = args[0] if not hasattr(cls, "__new__"): raise pickle.PicklingError( @@ -566,15 +779,6 @@ def save_reduce(self, func, args, state=None, args = args[1:] save(cls) - #Don't pickle transient entries - if hasattr(obj, '__transient__'): - transient = obj.__transient__ - state = state.copy() - - for k in list(state.keys()): - if k in transient: - del state[k] - save(args) write(pickle.NEWOBJ) else: @@ -623,72 +827,82 @@ def save_file(self, obj): return self.save_reduce(getattr, (sys,'stderr'), obj=obj) if obj is sys.stdin: raise pickle.PicklingError("Cannot pickle standard input") - if hasattr(obj, 'isatty') and obj.isatty(): + if obj.closed: + raise pickle.PicklingError("Cannot pickle closed files") + if hasattr(obj, 'isatty') and obj.isatty(): raise pickle.PicklingError("Cannot pickle files that map to tty objects") - if 'r' not in obj.mode: - raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + if 'r' not in obj.mode and '+' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading: %s" % obj.mode) + name = obj.name - try: - fsize = os.stat(name).st_size - except OSError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) - if obj.closed: - #create an empty closed string io - retval = pystringIO.StringIO("") - retval.close() - elif not fsize: #empty file - retval = pystringIO.StringIO("") - try: - tmpfile = file(name) - tst = tmpfile.read(1) - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - tmpfile.close() - if tst != '': - raise pickle.PicklingError("Cannot 
pickle file %s as it does not appear to map to a physical, real file" % name) - else: - try: - tmpfile = file(name) - contents = tmpfile.read() - tmpfile.close() - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - retval = pystringIO.StringIO(contents) + retval = pystringIO.StringIO() + + try: + # Read the whole file curloc = obj.tell() - retval.seek(curloc) + obj.seek(0) + contents = obj.read() + obj.seek(curloc) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval.write(contents) + retval.seek(curloc) retval.name = name self.save(retval) self.memoize(obj) + def save_ellipsis(self, obj): + self.save_reduce(_gen_ellipsis, ()) + + def save_not_implemented(self, obj): + self.save_reduce(_gen_not_implemented, ()) + if PY3: dispatch[io.TextIOWrapper] = save_file else: dispatch[file] = save_file - """Special functions for Add-on libraries""" + dispatch[type(Ellipsis)] = save_ellipsis + dispatch[type(NotImplemented)] = save_not_implemented - def inject_numpy(self): - numpy = sys.modules.get('numpy') - if not numpy or not hasattr(numpy, 'ufunc'): - return - self.dispatch[numpy.ufunc] = self.__class__.save_ufunc - - def save_ufunc(self, obj): - """Hack function for saving numpy ufunc objects""" - name = obj.__name__ - numpy_tst_mods = ['numpy', 'scipy.special'] - for tst_mod_name in numpy_tst_mods: - tst_mod = sys.modules.get(tst_mod_name, None) - if tst_mod and name in tst_mod.__dict__: - return self.save_reduce(_getobject, (tst_mod_name, name)) - raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' - % str(obj)) + # WeakSet was added in 2.7. + if hasattr(weakref, 'WeakSet'): + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) + dispatch[weakref.WeakSet] = save_weakset + + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" - self.inject_numpy() + pass + + def save_logger(self, obj): + self.save_reduce(logging.getLogger, (obj.name,), obj=obj) + + dispatch[logging.Logger] = save_logger + + +# Tornado support + +def is_tornado_coroutine(func): + """ + Return whether *func* is a Tornado coroutine function. + Running coroutines are not supported. 
+ """ + if 'tornado.gen' not in sys.modules: + return False + gen = sys.modules['tornado.gen'] + if not hasattr(gen, "is_coroutine_function"): + # Tornado version is too old + return False + return gen.is_coroutine_function(func) + +def _rebuild_tornado_coroutine(func): + from tornado import gen + return gen.coroutine(func) # Shorthands for legacy support @@ -705,6 +919,10 @@ def dumps(obj, protocol=2): return file.getvalue() +# including pickles unloading functions in this namespace +load = pickle.load +loads = pickle.loads + #hack for __import__ not working as desired def subimport(name): @@ -712,6 +930,12 @@ def subimport(name): return sys.modules[name] +def dynamic_subimport(name, vars): + mod = imp.new_module(name) + mod.__dict__.update(vars) + sys.modules[name] = mod + return mod + # restores function attributes def _restore_attr(obj, attr): for key, val in attr.items(): @@ -755,59 +979,114 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) +def _gen_ellipsis(): + return Ellipsis -def _fill_function(func, globals, defaults, dict, module): +def _gen_not_implemented(): + return NotImplemented + + +def _get_cell_contents(cell): + try: + return cell.cell_contents + except ValueError: + # sentinel used by ``_fill_function`` which will leave the cell empty + return _empty_cell_value + + +def instance(cls): + """Create a new instance of a class. + + Parameters + ---------- + cls : type + The class to create an instance of. + + Returns + ------- + instance : cls + A new instance of ``cls``. + """ + return cls() + + +@instance +class _empty_cell_value(object): + """sentinel for empty closures + """ + @classmethod + def __reduce__(cls): + return cls.__name__ + + +def _fill_function(func, globals, defaults, dict, module, closure_values): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). - """ + """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict func.__module__ = module - return func + cells = func.__closure__ + if cells is not None: + for cell, value in zip(cells, closure_values): + if value is not _empty_cell_value: + cell_set(cell, value) + return func -def _make_cell(value): - return (lambda: value).__closure__[0] +def _make_empty_cell(): + if False: + # trick the compiler into creating an empty cell in our lambda + cell = None + raise AssertionError('this route should not be executed') -def _reconstruct_closure(values): - return tuple([_make_cell(v) for v in values]) + return (lambda: cell).__closure__[0] -def _make_skel_func(code, closures, base_globals = None): +def _make_skel_func(code, cell_count, base_globals=None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other func attributes (e.g. func_globals) are empty. """ - closure = _reconstruct_closure(closures) if closures else None - if base_globals is None: base_globals = {} base_globals['__builtins__'] = __builtins__ - return types.FunctionType(code, base_globals, - None, None, closure) + closure = ( + tuple(_make_empty_cell() for _ in range(cell_count)) + if cell_count >= 0 else + None + ) + return types.FunctionType(code, base_globals, None, None, closure) -def _load_class(cls, d): - """ - Loads additional properties into class `cls`. +def _rehydrate_skeleton_class(skeleton_class, class_dict): + """Put attributes from `class_dict` back on `skeleton_class`. + + See CloudPickler.save_dynamic_class for more info. 
""" - for k, v in d.items(): - if isinstance(k, tuple): - typ, k = k - if typ == 'property': - v = property(*v) - elif typ == 'staticmethod': - v = staticmethod(v) - elif typ == 'classmethod': - v = classmethod(v) - setattr(cls, k, v) - return cls + for attrname, attr in class_dict.items(): + setattr(skeleton_class, attrname, attr) + return skeleton_class +def _find_module(mod_name): + """ + Iterate over each part instead of calling imp.find_module directly. + This function is able to find submodules (e.g. sickit.tree) + """ + path = None + for part in mod_name.split('.'): + if path is not None: + path = [path] + file, path, description = imp.find_module(part, path) + if file is not None: + file.close() + return path, description + def _load_namedtuple(name, fields): """ Loads a class generated by namedtuple @@ -815,10 +1094,24 @@ def _load_namedtuple(name, fields): from collections import namedtuple return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] + + +""" Use copy_reg to extend global pickle definitions """ + +if sys.version_info < (3, 4): + method_descriptor = type(str.upper) + + def _reduce_method_descriptor(obj): + return (getattr, (obj.__objclass__, obj.__name__)) + + try: + import copy_reg as copyreg + except ImportError: + import copyreg + copyreg.pickle(method_descriptor, _reduce_method_descriptor) From 5c9b3017279e4f20c364ae92a1fd059d4cfe9f4f Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 21 Aug 2017 23:08:27 -0700 Subject: [PATCH 023/187] [SPARK-21584][SQL][SPARKR] Update R method for summary to call new implementation ## What changes were proposed in this pull request? SPARK-21100 introduced a new `summary` method to the Scala/Java Dataset API that included expanded statistics (vs `describe`) and control over which statistics to compute. Currently in the R API `summary` acts as an alias for `describe`. This patch updates the R API to call the new `summary` method in the JVM that includes additional statistics and ability to select which to compute. This does not break the current interface as the present `summary` method does not take additional arguments like `describe` and the output was never meant to be used programmatically. ## How was this patch tested? Modified and additional unit tests. Author: Andrew Ray Closes #18786 from aray/summary-r. --- R/pkg/R/DataFrame.R | 44 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 2 +- R/pkg/tests/fulltests/test_sparkSQL.R | 19 ++++++++---- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 5d6f9c042248b..80526cdd4fd45 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2930,7 +2930,7 @@ setMethod("saveAsTable", invisible(callJMethod(write, "saveAsTable", tableName)) }) -#' summary +#' describe #' #' Computes statistics for numeric and string columns. #' If no columns are given, this function computes statistics for all numerical or string columns. @@ -2941,7 +2941,7 @@ setMethod("saveAsTable", #' @return A SparkDataFrame. 
#' @family SparkDataFrame functions
 #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method
-#' @rdname summary
+#' @rdname describe
 #' @name describe
 #' @export
 #' @examples
@@ -2953,6 +2953,7 @@ setMethod("saveAsTable",
 #' describe(df, "col1")
 #' describe(df, "col1", "col2")
 #' }
+#' @seealso See \link{summary} for expanded statistics and control over which statistics to compute.
 #' @note describe(SparkDataFrame, character) since 1.4.0
 setMethod("describe",
           signature(x = "SparkDataFrame", col = "character"),
@@ -2962,7 +2963,7 @@ setMethod("describe",
     dataFrame(sdf)
   })
 
-#' @rdname summary
+#' @rdname describe
 #' @name describe
 #' @aliases describe,SparkDataFrame-method
 #' @note describe(SparkDataFrame) since 1.4.0
@@ -2973,15 +2974,50 @@ setMethod("describe",
     dataFrame(sdf)
   })
 
+#' summary
+#'
+#' Computes specified statistics for numeric and string columns. Available statistics are:
+#' \itemize{
+#' \item count
+#' \item mean
+#' \item stddev
+#' \item min
+#' \item max
+#' \item arbitrary approximate percentiles specified as a percentage (e.g., "75%")
+#' }
+#' If no statistics are given, this function computes count, mean, stddev, min,
+#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
+#' This function is meant for exploratory data analysis, as we make no guarantee about the
+#' backward compatibility of the schema of the resulting Dataset. If you want to
+#' programmatically compute summary statistics, use the \code{agg} function instead.
+#'
+#' @param object a SparkDataFrame to be summarized.
+#' @param ... (optional) statistics to be computed for all columns.
+#' @return A SparkDataFrame.
+#' @family SparkDataFrame functions
 #' @rdname summary
 #' @name summary
 #' @aliases summary,SparkDataFrame-method
+#' @export
+#' @examples
+#'\dontrun{
+#' sparkR.session()
+#' path <- "path/to/file.json"
+#' df <- read.json(path)
+#' summary(df)
+#' summary(df, "min", "25%", "75%", "max")
+#' summary(select(df, "age", "height"))
+#' }
 #' @note summary(SparkDataFrame) since 1.5.0
+#' @note The statistics provided by \code{summary} were changed in 2.3.0; use \link{describe} for the previous defaults.
+#' @seealso \link{describe}
 setMethod("summary",
           signature(object = "SparkDataFrame"),
           function(object, ...) {
-    describe(object)
+    statisticsList <- list(...)
+    sdf <- callJMethod(object@sdf, "summary", statisticsList)
+    dataFrame(sdf)
   })
 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index df91c35f7d851..f0cc2dc3f195a 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -521,7 +521,7 @@ setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect")
 # @export
 setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") })
 
-#' @rdname summary
+#' @rdname describe
 #' @export
 setGeneric("describe", function(x, col, ...) 
{ standardGeneric("describe") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index deb0e163a8d58..d477fc6a4256c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2497,7 +2497,7 @@ test_that("read/write text files - compression option", { unlink(textPath) }) -test_that("describe() and summarize() on a DataFrame", { +test_that("describe() and summary() on a DataFrame", { df <- read.json(jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") @@ -2508,8 +2508,15 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats)[5, "age"], "30") stats2 <- summary(df) - expect_equal(collect(stats2)[4, "summary"], "min") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[5, "summary"], "25%") + expect_equal(collect(stats2)[5, "age"], "30.0") + + stats3 <- summary(df, "min", "max", "55.1%") + + expect_equal(collect(stats3)[1, "summary"], "min") + expect_equal(collect(stats3)[2, "summary"], "max") + expect_equal(collect(stats3)[3, "summary"], "55.1%") + expect_equal(collect(stats3)[3, "age"], "30.0") # SPARK-16425: SparkR summary() fails on column of type logical df <- withColumn(df, "boolean", df$age == 30) @@ -2742,15 +2749,15 @@ test_that("attach() on a DataFrame", { expected_age <- data.frame(age = c(NA, 30, 19)) expect_equal(head(age), expected_age) stat <- summary(age) - expect_equal(collect(stat)[5, "age"], "30") + expect_equal(collect(stat)[8, "age"], "30") age <- age$age + 1 expect_is(age, "Column") rm(age) stat2 <- summary(age) - expect_equal(collect(stat2)[5, "age"], "30") + expect_equal(collect(stat2)[8, "age"], "30") detach("df") stat3 <- summary(df[, "age", drop = F]) - expect_equal(collect(stat3)[5, "age"], "30") + expect_equal(collect(stat3)[8, "age"], "30") expect_error(age) }) From be72b157ea13ea116c5178a9e41e37ae24090f72 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 17:54:39 +0800 Subject: [PATCH 024/187] [SPARK-21803][TEST] Remove the HiveDDLCommandSuite ## What changes were proposed in this pull request? We do not have any Hive-specific parser. It does not make sense to keep a parser-specific test suite `HiveDDLCommandSuite.scala` in the Hive package. This PR is to remove it. ## How was this patch tested? N/A Author: gatorsmile Closes #19015 from gatorsmile/combineDDL. 
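For orientation, the relocated tests keep the plain `SparkSqlParser`-based pattern already used in sql/core. The sketch below is a condensed, illustrative example of that pattern only; the class name and test here are invented for illustration, and the actual relocated tests appear in full in the diff that follows.

```scala
// Illustrative sketch of the consolidated parser-test style (not part of this patch).
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ExampleDDLParserSuite extends PlanTest with SharedSQLContext {
  private lazy val parser = new SparkSqlParser(new SQLConf)

  // Pull the CatalogTable out of the parsed CREATE TABLE plan.
  private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
    parser.parsePlan(sql).collect {
      case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
    }.head
  }

  test("example: external table keeps its table type") {
    val (desc, _) = extractTableDesc(
      "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'")
    assert(desc.tableType == CatalogTableType.EXTERNAL)
  }
}
```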
--- ...ommandSuite.scala => DDLParserSuite.scala} | 572 +++++++++++++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 739 ------------------ .../apache/spark/sql/hive/TestHiveSuite.scala | 4 + .../sql/hive/execution/HiveDDLSuite.scala | 11 + .../sql/hive/execution/HiveSerDeSuite.scala | 133 +++- 5 files changed, 716 insertions(+), 743 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/execution/command/{DDLCommandSuite.scala => DDLParserSuite.scala} (62%) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala similarity index 62% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 5643c58d9f847..70df7607a713f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -22,19 +22,25 @@ import java.util.Locale import scala.reflect.{classTag, ClassTag} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, Project, ScriptTransformation} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -// TODO: merge this with DDLSuite (SPARK-14441) -class DDLCommandSuite extends PlanTest { +class DDLParserSuite extends PlanTest with SharedSQLContext { private lazy val parser = new SparkSqlParser(new SQLConf) private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { @@ -56,6 +62,17 @@ class DDLCommandSuite extends PlanTest { } } + private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { + val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) + comparePlans(plan, expected, checkAnalysis = false) + } + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parser.parsePlan(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + test("create database") { val sql = """ @@ -1046,4 +1063,553 @@ class DDLCommandSuite extends PlanTest { s"got ${other.getClass.getName}: $sql") } } + + test("Test CTAS #1") { + val s1 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + 
val (desc, exists) = extractTableDesc(s1) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + + test("Test CTAS #2") { + val s2 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s2) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + 
} + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "ctas2") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("CTAS statement with a PARTITIONED BY clause is not allowed") { + assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + } + + test("CTAS statement with schema") { + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + } + + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TEMPORARY TABLE ctas2 + |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + |STORED AS RCFile + |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |CLUSTERED BY(user_id) INTO 256 BUCKETS + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |SKEWED BY (key) ON (1,5,6) + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' + |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' + |FROM testData + """.stripMargin) + } + } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val analyzer = spark.sessionState.analyzer + val plan = 
analyzer.execute(parser.parsePlan( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin)) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } + + test("transform query spec") { + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + compareTransformQuery("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("use backticks in output of Script Transform") { + parser.parsePlan( + """SELECT `t`.`thing1` + |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) + |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t + """.stripMargin) + } + + test("use backticks in output of Generator") { + parser.parsePlan( + """ + |SELECT `gentab2`.`gencol2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` + |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` + """.stripMargin) + } + + test("use escaped backticks in output of Generator") { + parser.parsePlan( + """ + |SELECT `gen``tab2`.`gen``col2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` + |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` + """.stripMargin) + } + + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.schema == new StructType().add("id", "int").add("name", "string")) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.storage.properties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL) + 
assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val numBuckets = 10 + val bucketedColumn = "id" + val sortColumn = "id" + val baseQuery = + s""" + CREATE TABLE my_table ( + $bucketedColumn int, + name string) + CLUSTERED BY($bucketedColumn) + """ + + val query1 = s"$baseQuery INTO $numBuckets BUCKETS" + val (desc1, _) = extractTableDesc(query1) + assert(desc1.bucketSpec.isDefined) + val bucketSpec1 = desc1.bucketSpec.get + assert(bucketSpec1.numBuckets == numBuckets) + assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec1.sortColumnNames.isEmpty) + + val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" + val (desc2, _) = extractTableDesc(query2) + assert(desc2.bucketSpec.isDefined) + val bucketSpec2 = desc2.bucketSpec.get + assert(bucketSpec2.numBuckets == numBuckets) + assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.properties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.properties == Map("k1" -> "v1")) + assert(desc3.storage.properties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", 
// yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.properties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + + test("create view -- basic") { + val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" + val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] + assert(!command.allowExisting) + assert(command.name.database.isEmpty) + assert(command.name.table == "view1") + assert(command.originalText == Some("SELECT * FROM tab1")) + assert(command.userSpecifiedColumns.isEmpty) + } + + test("create view - full") { + val v1 = + """ + |CREATE OR REPLACE VIEW view1 + |(col1, col3 COMMENT 'hello') + |COMMENT 'BLABLA' + |TBLPROPERTIES('prop1Key'="prop1Val") + |AS SELECT 
* FROM tab1 + """.stripMargin + val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] + assert(command.name.database.isEmpty) + assert(command.name.table == "view1") + assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello"))) + assert(command.originalText == Some("SELECT * FROM tab1")) + assert(command.properties == Map("prop1Key" -> "prop1Val")) + assert(command.comment == Some("BLABLA")) + } + + test("create view -- partitioned view") { + val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" + intercept[ParseException] { + parser.parsePlan(v1) + } + } + + test("MSCK REPAIR table") { + val sql = "MSCK REPAIR TABLE tab1" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("tab1", None), + "MSCK REPAIR TABLE") + comparePlans(parsed, expected) + } + + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, location, exists) = parser.parsePlan(v1).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + assert(location.isEmpty) + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + assert(location2.isEmpty) + + val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'" + val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(!exists3) + assert(target3.database.isEmpty) + assert(target3.table == "table1") + assert(source3.database.isEmpty) + assert(source3.table == "table2") + assert(location3 == Some("/spark/warehouse")) + + val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'" + val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect { + case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) + }.head + assert(exists4) + assert(target4.database.isEmpty) + assert(target4.table == "table1") + assert(source4.database.isEmpty) + assert(source4.table == "table2") + assert(location4 == Some("/spark/warehouse")) + } + + test("load data") { + val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1" + val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect { + case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) + }.head + assert(table.database.isEmpty) + assert(table.table == "table1") + assert(path == "path") + assert(!isLocal) + assert(!isOverwrite) + assert(partition.isEmpty) + + val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')" + val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect { + case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) + }.head + assert(table2.database.isEmpty) + assert(table2.table == "table1") + assert(path2 == "path") + assert(isLocal2) + assert(isOverwrite2) + assert(partition2.nonEmpty) + assert(partition2.get.apply("c") == "1" && 
partition2.get.apply("d") == "2") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala deleted file mode 100644 index bee470d8e1382..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ /dev/null @@ -1,739 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.Locale - -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, ScriptTransformation} -import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.CreateTable -import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType - -class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton { - val parser = TestHive.sessionState.sqlParser - - private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { - parser.parsePlan(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) - }.head - } - - private def assertUnsupported(sql: String): Unit = { - val e = intercept[ParseException] { - parser.parsePlan(sql) - } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) - } - - private def analyzeCreateTable(sql: String): CatalogTable = { - TestHive.sessionState.analyzer.execute(parser.parsePlan(sql)).collect { - case CreateTableCommand(tableDesc, _) => tableDesc - }.head - } - - private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { - val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) - comparePlans(plan, expected, checkAnalysis = false) - } - - test("Test CTAS #1") { - val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |STORED AS RCFILE - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = 
extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - - test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - - test("Test CTAS #3") { - val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val (desc, exists) = extractTableDesc(s3) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.properties == Map()) - } - - test("Test CTAS #4") { - val s4 = - """CREATE TABLE page_view - |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } - } - - test("Test CTAS 
#5") { - val s5 = """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin - val (desc, exists) = extractTableDesc(s5) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "ctas2") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) - } - - test("CTAS statement with a PARTITIONED BY clause is not allowed") { - assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + - " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") - } - - test("CTAS statement with schema") { - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") - } - - test("unsupported operations") { - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TEMPORARY TABLE ctas2 - |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - |STORED AS RCFile - |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) - |CLUSTERED BY(user_id) INTO 256 BUCKETS - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) - |SKEWED BY (key) ON (1,5,6) - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } - intercept[ParseException] { - parser.parsePlan( - """ - |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' - |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' - |FROM testData - """.stripMargin) - } - } - - test("Invalid interval term should throw AnalysisException") { - def assertError(sql: String, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - parser.parsePlan(sql) - } - assert(e.getMessage.contains(errorMessage)) - } - assertError("select interval '42-32' year to month", - "month 32 outside range [0, 11]") - assertError("select interval '5 49:12:15' day to second", - "hour 49 outside range [0, 23]") - assertError("select interval '.1111111111' second", - "nanosecond 1111111111 outside range") - } - - test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { - val analyzer = TestHive.sparkSession.sessionState.analyzer - val plan = 
analyzer.execute(parser.parsePlan( - """ - |SELECT * - |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test - |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b - """.stripMargin)) - - assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) - } - - test("transform query spec") { - val p = ScriptTransformation( - Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), - "func", Seq.empty, plans.table("e"), null) - - compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - compareTransformQuery("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("use backticks in output of Script Transform") { - parser.parsePlan( - """SELECT `t`.`thing1` - |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) - |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t - """.stripMargin) - } - - test("use backticks in output of Generator") { - parser.parsePlan( - """ - |SELECT `gentab2`.`gencol2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` - |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` - """.stripMargin) - } - - test("use escaped backticks in output of Generator") { - parser.parsePlan( - """ - |SELECT `gen``tab2`.`gen``col2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` - |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` - """.stripMargin) - } - - test("create table - basic") { - val query = "CREATE TABLE my_table (id int, name string)" - val (desc, allowExisting) = extractTableDesc(query) - assert(!allowExisting) - assert(desc.identifier.database.isEmpty) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.schema == new StructType().add("id", "int").add("name", "string")) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri.isEmpty) - assert(desc.storage.inputFormat == - Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.storage.properties.isEmpty) - assert(desc.properties.isEmpty) - assert(desc.comment.isEmpty) - } - - test("create table - with database name") { - val query = "CREATE TABLE dbx.my_table (id int, name string)" - val (desc, _) = extractTableDesc(query) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - } - - test("create table - temporary") { - val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" - val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) - } - - test("create table - external") { - val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" - val (desc, _) = extractTableDesc(query) - assert(desc.tableType == CatalogTableType.EXTERNAL) - 
assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) - } - - test("create table - if not exists") { - val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" - val (_, allowExisting) = extractTableDesc(query) - assert(allowExisting) - } - - test("create table - comment") { - val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" - val (desc, _) = extractTableDesc(query) - assert(desc.comment == Some("its hot as hell below")) - } - - test("create table - partitioned columns") { - val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" - val (desc, _) = extractTableDesc(query) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - } - - test("create table - clustered by") { - val numBuckets = 10 - val bucketedColumn = "id" - val sortColumn = "id" - val baseQuery = - s""" - CREATE TABLE my_table ( - $bucketedColumn int, - name string) - CLUSTERED BY($bucketedColumn) - """ - - val query1 = s"$baseQuery INTO $numBuckets BUCKETS" - val (desc1, _) = extractTableDesc(query1) - assert(desc1.bucketSpec.isDefined) - val bucketSpec1 = desc1.bucketSpec.get - assert(bucketSpec1.numBuckets == numBuckets) - assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec1.sortColumnNames.isEmpty) - - val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" - val (desc2, _) = extractTableDesc(query2) - assert(desc2.bucketSpec.isDefined) - val bucketSpec2 = desc2.bucketSpec.get - assert(bucketSpec2.numBuckets == numBuckets) - assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) - } - - test("create table - skewed by") { - val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" - val query1 = s"$baseQuery(id) ON (1, 10, 100)" - val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" - val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - val e3 = intercept[ParseException] { parser.parsePlan(query3) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - assert(e3.getMessage.contains("Operation not allowed")) - } - - test("create table - row format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" - val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" - val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" - val query3 = - s""" - |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' - |COLLECTION ITEMS TERMINATED BY 'a' - |MAP KEYS TERMINATED BY 'b' - |LINES TERMINATED BY '\n' - |NULL DEFINED AS 'c' - """.stripMargin - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - val (desc3, _) = extractTableDesc(query3) - assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc1.storage.properties.isEmpty) - assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc2.storage.properties == Map("k1" -> "v1")) - assert(desc3.storage.properties == Map( - "field.delim" -> "x", - "escape.delim" -> "y", - "serialization.format" -> "x", - "line.delim" -> "\n", - "colelction.delim" -> "a", 
// yes, it's a typo from Hive :) - "mapkey.delim" -> "b")) - } - - test("create table - file format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" - val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" - val query2 = s"$baseQuery ORC" - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - assert(desc1.storage.inputFormat == Some("winput")) - assert(desc1.storage.outputFormat == Some("wowput")) - assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - test("create table - storage handler") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" - val query1 = s"$baseQuery 'org.papachi.StorageHandler'" - val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - } - - test("create table - properties") { - val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - val (desc, _) = extractTableDesc(query) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - } - - test("create table - everything!") { - val query = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) - |COMMENT 'no comment' - |PARTITIONED BY (month int) - |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') - |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' - |LOCATION '/path/to/mercury' - |TBLPROPERTIES ('k1'='v1', 'k2'='v2') - """.stripMargin - val (desc, allowExisting) = extractTableDesc(query) - assert(allowExisting) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) - assert(desc.storage.inputFormat == Some("winput")) - assert(desc.storage.outputFormat == Some("wowput")) - assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc.storage.properties == Map("k1" -> "v1")) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - assert(desc.comment == Some("no comment")) - } - - test("create view -- basic") { - val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" - val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] - assert(!command.allowExisting) - assert(command.name.database.isEmpty) - assert(command.name.table == "view1") - assert(command.originalText == Some("SELECT * FROM tab1")) - assert(command.userSpecifiedColumns.isEmpty) - } - - test("create view - full") { - val v1 = - """ - |CREATE OR REPLACE VIEW view1 - |(col1, col3 COMMENT 'hello') - |COMMENT 'BLABLA' - |TBLPROPERTIES('prop1Key'="prop1Val") - |AS SELECT 
* FROM tab1 - """.stripMargin - val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand] - assert(command.name.database.isEmpty) - assert(command.name.table == "view1") - assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello"))) - assert(command.originalText == Some("SELECT * FROM tab1")) - assert(command.properties == Map("prop1Key" -> "prop1Val")) - assert(command.comment == Some("BLABLA")) - } - - test("create view -- partitioned view") { - val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" - intercept[ParseException] { - parser.parsePlan(v1) - } - } - - test("MSCK REPAIR table") { - val sql = "MSCK REPAIR TABLE tab1" - val parsed = parser.parsePlan(sql) - val expected = AlterTableRecoverPartitionsCommand( - TableIdentifier("tab1", None), - "MSCK REPAIR TABLE") - comparePlans(parsed, expected) - } - - test("create table like") { - val v1 = "CREATE TABLE table1 LIKE table2" - val (target, source, location, exists) = parser.parsePlan(v1).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists == false) - assert(target.database.isEmpty) - assert(target.table == "table1") - assert(source.database.isEmpty) - assert(source.table == "table2") - assert(location.isEmpty) - - val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" - val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists2) - assert(target2.database.isEmpty) - assert(target2.table == "table1") - assert(source2.database.isEmpty) - assert(source2.table == "table2") - assert(location2.isEmpty) - - val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'" - val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(!exists3) - assert(target3.database.isEmpty) - assert(target3.table == "table1") - assert(source3.database.isEmpty) - assert(source3.table == "table2") - assert(location3 == Some("/spark/warehouse")) - - val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'" - val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect { - case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting) - }.head - assert(exists4) - assert(target4.database.isEmpty) - assert(target4.table == "table1") - assert(source4.database.isEmpty) - assert(source4.table == "table2") - assert(location4 == Some("/spark/warehouse")) - } - - test("load data") { - val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1" - val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect { - case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) - }.head - assert(table.database.isEmpty) - assert(table.table == "table1") - assert(path == "path") - assert(!isLocal) - assert(!isOverwrite) - assert(partition.isEmpty) - - val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')" - val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect { - case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) - }.head - assert(table2.database.isEmpty) - assert(table2.table == "table1") - assert(path2 == "path") - assert(isLocal2) - assert(isOverwrite2) - assert(partition2.nonEmpty) - assert(partition2.get.apply("c") == "1" && 
partition2.get.apply("d") == "2") - } - - test("Test the default fileformat for Hive-serde tables") { - withSQLConf("hive.default.fileformat" -> "orc") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - withSQLConf("hive.default.fileformat" -> "parquet") { - val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") - assert(exists) - val input = desc.storage.inputFormat - val output = desc.storage.outputFormat - val serde = desc.storage.serde - assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - } - - test("table name with schema") { - // regression test for SPARK-11778 - spark.sql("create schema usrdb") - spark.sql("create table usrdb.test(c int)") - spark.read.table("usrdb.test") - spark.sql("drop table usrdb.test") - spark.sql("drop schema usrdb") - } - - test("SPARK-15887: hive-site.xml should be loaded") { - assert(hiveClient.getConf("hive.in.test", "") == "true") - } - - test("create hive serde table with new syntax - basic") { - val sql = - """ - |CREATE TABLE t - |(id int, name string COMMENT 'blabla') - |USING hive - |OPTIONS (fileFormat 'parquet', my_prop 1) - |LOCATION '/tmp/file' - |COMMENT 'BLABLA' - """.stripMargin - - val table = analyzeCreateTable(sql) - assert(table.schema == new StructType() - .add("id", "int") - .add("name", "string", nullable = true, comment = "blabla")) - assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) - assert(table.storage.locationUri == Some(new URI("/tmp/file"))) - assert(table.storage.properties == Map("my_prop" -> "1")) - assert(table.comment == Some("BLABLA")) - - assert(table.storage.inputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) - assert(table.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - assert(table.storage.serde == - Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - test("create hive serde table with new syntax - with partition and bucketing") { - val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)" - val table = analyzeCreateTable(v1) - assert(table.schema == new StructType().add("c1", "int").add("c2", "int")) - assert(table.partitionColumnNames == Seq("c2")) - // check the default formats - assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(table.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - - val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS" - val e2 = intercept[AnalysisException](analyzeCreateTable(v2)) - assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet")) - - val v3 = - """ - |CREATE TABLE t (c1 int, c2 int) USING hive - |PARTITIONED BY (c2) - |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin - val e3 = 
intercept[AnalysisException](analyzeCreateTable(v3)) - assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet")) - } - - test("create hive serde table with new syntax - Hive options error checking") { - val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')" - val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1)) - assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat")) - - val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " + - "(fileFormat 'x', inputFormat 'a', outputFormat 'b')" - val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2)) - assert(e2.getMessage.contains( - "Cannot specify fileFormat and inputFormat/outputFormat together")) - - val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')" - val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3)) - assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde")) - - val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')" - val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4)) - assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde")) - - val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')" - val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5)) - assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat")) - - val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')" - val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6)) - assert(e6.getMessage.contains( - "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'")) - - // The value of 'fileFormat' option is case-insensitive. 
- val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')" - val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7)) - assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter")) - - val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')" - val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8)) - assert(e8.getMessage.contains("invalid fileFormat: 'wrong'")) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala index 193fa83dbad99..72f8e8ff7c688 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala @@ -42,4 +42,8 @@ class TestHiveSuite extends TestHiveSingleton with SQLTestUtils { } testHiveSparkSession.reset() } + + test("SPARK-15887: hive-site.xml should be loaded") { + assert(hiveClient.getConf("hive.in.test", "") == "true") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 4c2fea3eb68bc..ee64bc9f9ee04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1998,4 +1998,15 @@ class HiveDDLSuite sq.stop() } } + + test("table name with schema") { + // regression test for SPARK-11778 + withDatabase("usrdb") { + spark.sql("create schema usrdb") + withTable("usrdb.test") { + spark.sql("create table usrdb.test(c int)") + spark.read.table("usrdb.test") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7803ac39e508b..1c9f00141ae1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.sql.hive.execution +import java.net.URI + import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StructType /** * A set of tests that validates support for Hive SerDe. 
*/ -class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { +class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { import TestHive._ import org.apache.hadoop.hive.serde2.RegexSerDe @@ -60,4 +68,127 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF()) assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil) } + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + TestHive.sessionState.sqlParser.parsePlan(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + + private def analyzeCreateTable(sql: String): CatalogTable = { + TestHive.sessionState.analyzer.execute(TestHive.sessionState.sqlParser.parsePlan(sql)).collect { + case CreateTableCommand(tableDesc, _) => tableDesc + }.head + } + + test("Test the default fileformat for Hive-serde tables") { + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + withSQLConf("hive.default.fileformat" -> "parquet") { + val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)") + assert(exists) + val input = desc.storage.inputFormat + val output = desc.storage.outputFormat + val serde = desc.storage.serde + assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + } + + test("create hive serde table with new syntax - basic") { + val sql = + """ + |CREATE TABLE t + |(id int, name string COMMENT 'blabla') + |USING hive + |OPTIONS (fileFormat 'parquet', my_prop 1) + |LOCATION '/tmp/file' + |COMMENT 'BLABLA' + """.stripMargin + + val table = analyzeCreateTable(sql) + assert(table.schema == new StructType() + .add("id", "int") + .add("name", "string", nullable = true, comment = "blabla")) + assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) + assert(table.storage.locationUri == Some(new URI("/tmp/file"))) + assert(table.storage.properties == Map("my_prop" -> "1")) + assert(table.comment == Some("BLABLA")) + + assert(table.storage.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(table.storage.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + + test("create hive serde table with new syntax - with partition and bucketing") { + val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)" + val table = analyzeCreateTable(v1) + assert(table.schema == new StructType().add("c1", "int").add("c2", "int")) + assert(table.partitionColumnNames == Seq("c2")) + // check the default formats + assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + 
assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + + val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS" + val e2 = intercept[AnalysisException](analyzeCreateTable(v2)) + assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet")) + + val v3 = + """ + |CREATE TABLE t (c1 int, c2 int) USING hive + |PARTITIONED BY (c2) + |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin + val e3 = intercept[AnalysisException](analyzeCreateTable(v3)) + assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet")) + } + + test("create hive serde table with new syntax - Hive options error checking") { + val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')" + val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1)) + assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat")) + + val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " + + "(fileFormat 'x', inputFormat 'a', outputFormat 'b')" + val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2)) + assert(e2.getMessage.contains( + "Cannot specify fileFormat and inputFormat/outputFormat together")) + + val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')" + val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3)) + assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde")) + + val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')" + val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4)) + assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde")) + + val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')" + val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5)) + assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat")) + + val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')" + val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6)) + assert(e6.getMessage.contains( + "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'")) + + // The value of 'fileFormat' option is case-insensitive. + val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')" + val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7)) + assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter")) + + val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')" + val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8)) + assert(e8.getMessage.contains("invalid fileFormat: 'wrong'")) + } } From 3ed1ae10052e456d4efb13b2e51b5a43f7b1609a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 22 Aug 2017 10:14:41 -0700 Subject: [PATCH 025/187] [SPARK-20641][CORE] Add missing kvstore module in Laucher and SparkSubmit code There're two code in Launcher and SparkSubmit will will explicitly list all the Spark submodules, newly added kvstore module is missing in this two parts, so submitting a minor PR to fix this. Author: jerryshao Closes #19014 from jerryshao/missing-kvstore. 
--- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../java/org/apache/spark/launcher/AbstractCommandBuilder.java | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e7e8fbc25d0ec..e56925102d47e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -988,7 +988,7 @@ private[spark] object SparkSubmitUtils { // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and // other spark-streaming utility components. Underscore is there to differentiate between // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x - val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "launcher_", "mllib_", + val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "kvstore_", "launcher_", "mllib_", "mllib-local_", "network-common_", "network-shuffle_", "repl_", "sketch_", "sql_", "streaming_", "tags_", "unsafe_") diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 44028c58ac489..c32974a57fccc 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -144,6 +144,7 @@ List buildClassPath(String appClassPath) throws IOException { if (prependClasses || isTesting) { String scala = getScalaVersion(); List projects = Arrays.asList( + "common/kvstore", "common/network-common", "common/network-shuffle", "common/network-yarn", From 43d71d96596baa8d2111a4b20bf21c1c668ad793 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 13:01:35 -0700 Subject: [PATCH 026/187] [SPARK-21499][SQL] Support creating persistent function for Spark UDAF(UserDefinedAggregateFunction) ## What changes were proposed in this pull request? This PR is to enable users to create persistent Scala UDAF (that extends UserDefinedAggregateFunction). ```SQL CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg' ``` Before this PR, Spark UDAF only can be registered through the API `spark.udf.register(...)` ## How was this patch tested? Added test cases Author: gatorsmile Closes #18700 from gatorsmile/javaUDFinScala. 
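For readers of this patch, a minimal sketch of the kind of Scala UDAF the new code path accepts. The class below is hypothetical (the actual test class is `test.org.apache.spark.sql.MyDoubleAvg`, which ships in a test jar and is not reproduced here); it only illustrates the `UserDefinedAggregateFunction` contract that `SessionCatalog.makeFunctionExpression` now wraps in a `ScalaUDAF`:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

// Hypothetical UDAF: any class like this that extends UserDefinedAggregateFunction
// and is loadable on the classpath can now be registered with CREATE FUNCTION.
class SimpleDoubleSum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("value", DoubleType)
  override def bufferSchema: StructType = new StructType().add("sum", DoubleType)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  override def evaluate(buffer: Row): Double = buffer.getDouble(0)
}
```

Once a class like this is compiled and available on the classpath (the new error message in this patch reminds users of that requirement), it can be registered persistently with `CREATE FUNCTION` and then invoked from SQL like any built-in aggregate.
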
--- .../sql/catalyst/catalog/SessionCatalog.scala | 41 ++++++- .../test/resources/sql-tests/inputs/udaf.sql | 13 +++ .../resources/sql-tests/results/udaf.sql.out | 54 ++++++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 62 +++++------ .../sql/hive/execution/HiveUDAFSuite.scala | 13 +++ .../sql/hive/execution/HiveUDFSuite.scala | 101 ++++++++++-------- .../sql/hive/execution/SQLQuerySuite.scala | 42 +------- 7 files changed, 204 insertions(+), 122 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udaf.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6030d90ed99c3..0908d68d25649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.Locale import java.util.concurrent.Callable @@ -24,6 +25,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -39,7 +41,9 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -1075,13 +1079,33 @@ class SessionCatalog( // ---------------------------------------------------------------- /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. + */ + private def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + val clazz = Utils.classForName(functionClassName) + (input: Seq[Expression]) => makeFunctionExpression(name, clazz, input) + } + + /** + * Constructs a [[Expression]] based on the provided class that represents a function. * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { - // TODO: at least support UDAFs here - throw new UnsupportedOperationException("Use sqlContext.udf.register(...) 
instead.") + protected def makeFunctionExpression( + name: String, + clazz: Class[_], + input: Seq[Expression]): Expression = { + val clsForUDAF = + Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") + if (clsForUDAF.isAssignableFrom(clazz)) { + val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") + cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .asInstanceOf[Expression] + } else { + throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " + + s"Use sparkSession.udf.register(...) instead.") + } } /** @@ -1105,7 +1129,14 @@ class SessionCatalog( } val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) val builder = - functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionBuilder.getOrElse { + val className = funcDefinition.className + if (!Utils.classIsLoadable(className)) { + throw new AnalysisException(s"Can not load class '$className' when registering " + + s"the function '$func', please make sure it is on the classpath") + } + makeFunctionBuilder(func.unquotedString, className) + } functionRegistry.registerFunction(func, info, builder) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql new file mode 100644 index 0000000000000..2183ba23afc38 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1); + +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'; + +SELECT default.myDoubleAvg(int_col1) as my_avg from t1; + +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1; + +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; + +SELECT default.udaf1(int_col1) as udaf1 from t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out new file mode 100644 index 0000000000000..4815a578b1029 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg' +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT default.myDoubleAvg(int_col1) as my_avg from t1 +-- !query 2 schema +struct +-- !query 2 output +102.5 + + +-- !query 3 +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 +-- !query 3 schema +struct<> +-- !query 3 output +java.lang.AssertionError +assertion failed: Incorrect number of children + + +-- !query 4 +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf' +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT default.udaf1(int_col1) as udaf1 from t1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0d0269f694300..b352bf6971bad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} -import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( @@ -58,55 +57,52 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { - makeFunctionBuilder(funcName, Utils.classForName(className)) - } - /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * Constructs a [[Expression]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { - // When we instantiate hive UDF wrapper class, we may throw exception if the input - // expressions don't satisfy the hive UDF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. - (children: Seq[Expression]) => { + override def makeFunctionExpression( + name: String, + clazz: Class[_], + input: Seq[Expression]): Expression = { + + Try(super.makeFunctionExpression(name, clazz, input)).getOrElse { + var udfExpr: Option[Expression] = None try { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. if (classOf[UDF].isAssignableFrom(clazz)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) - udaf.dataType // Force it to check input data types. - udaf + udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[UDAF].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction( + udfExpr = Some(HiveUDAFFunction( name, new HiveFunctionWrapper(clazz.getName), - children, - isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf + input, + isUDAFBridgeRequired = true)) + udfExpr.get.dataType // Force it to check input data types. 
} else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) - udtf.elementSchema // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") + udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types. } } catch { - case ae: AnalysisException => - throw ae case NonFatal(e) => val analysisException = - new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") + new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") analysisException.setStackTrace(e.getStackTrace) throw analysisException } + udfExpr.getOrElse { + throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 479ca1e8def56..8986fb58c6460 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo +import test.org.apache.spark.sql.MyDoubleAvg import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec @@ -86,6 +87,18 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { )) } + test("call JAVA UDAF") { + withTempView("temp") { + withUserDefinedFunction("myDoubleAvg" -> false) { + spark.range(1, 10).toDF("value").createOrReplaceTempView("temp") + sql(s"CREATE FUNCTION myDoubleAvg AS '${classOf[MyDoubleAvg].getName}'") + checkAnswer( + spark.sql("SELECT default.myDoubleAvg(value) as my_avg from temp"), + Row(105.0)) + } + } + } + test("non-deterministic children expressions of UDAF") { withTempView("view1") { spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index cae338c0ab0ae..383d41f907c6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -404,59 +404,34 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { - Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + withTempView("testUDF") { + Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + + def testErrorMsgForFunc(funcName: String, className: String): Unit = { + withUserDefinedFunction(funcName -> true) { + sql(s"CREATE TEMPORARY FUNCTION $funcName AS '$className'") + val message = intercept[AnalysisException] { + sql(s"SELECT $funcName() FROM testUDF") + }.getMessage + assert(message.contains(s"No handler for UDF/UDAF/UDTF 
'$className'")) + } + } - { // HiveSimpleUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFTwoListList() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - } + testErrorMsgForFunc("testUDFTwoListList", classOf[UDFTwoListList].getName) - { // HiveGenericUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFAnd() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") - } + testErrorMsgForFunc("testUDFAnd", classOf[GenericUDFOPAnd].getName) - { // Hive UDAF - sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") - } + testErrorMsgForFunc("testUDAFPercentile", classOf[UDAFPercentile].getName) - { // AbstractGenericUDAFResolver - sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") - } + testErrorMsgForFunc("testUDAFAverage", classOf[GenericUDAFAverage].getName) - { - // Hive UDTF - sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDTFExplode() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") + // AbstractGenericUDAFResolver + testErrorMsgForFunc("testUDTFExplode", classOf[GenericUDTFExplode].getName) } - - spark.catalog.dropTempView("testUDF") } test("Hive UDF in group by") { @@ -621,6 +596,46 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("UDTF") { + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + 
} } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d0e0d20df30af..02cfa02a37886 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -98,46 +98,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil) } - test("UDTF") { - withUserDefinedFunction("udtf_count2" -> true) { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - - test("permanent UDTF") { - withUserDefinedFunction("udtf_count_temp" -> false) { - sql( - s""" - |CREATE FUNCTION udtf_count_temp - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") From 01a8e46278dbfde916a74b6fd51e08804602e1cf Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 13:12:59 -0700 Subject: [PATCH 027/187] [SPARK-21769][SQL] Add a table-specific option for always respecting schemas inferred/controlled by Spark SQL ## What changes were proposed in this pull request? For Hive-serde tables, we always respect the schema stored in Hive metastore, because the schema could be altered by the other engines that share the same metastore. Thus, we always trust the metastore-controlled schema for Hive-serde tables when the schemas are different (without considering the nullability and cases). However, in some scenarios, Hive metastore also could INCORRECTLY overwrite the schemas when the serde and Hive metastore built-in serde are different. The proposed solution is to introduce a table-specific option for such scenarios. 
For a specific table, users can make Spark always respect Spark-inferred/controlled schema instead of trusting metastore-controlled schema. By default, we trust Hive metastore-controlled schema. ## How was this patch tested? Added a cross-version test case Author: gatorsmile Closes #19003 from gatorsmile/respectSparkSchema. --- .../execution/datasources/SourceOptions.scala | 50 ++++++++++++++++++ .../spark/sql/hive/HiveExternalCatalog.scala | 11 ++-- .../test/resources/avroDecimal/decimal.avro | Bin 0 -> 203 bytes .../spark/sql/hive/client/VersionsSuite.scala | 41 ++++++++++++++ 4 files changed, 97 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala create mode 100755 sql/hive/src/test/resources/avroDecimal/decimal.avro diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala new file mode 100644 index 0000000000000..c98c0b2a756a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +/** + * Options for the data source. + */ +class SourceOptions( + @transient private val parameters: CaseInsensitiveMap[String]) + extends Serializable { + import SourceOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + // A flag to disable saving a data source table's metadata in hive compatible way. 
+ val skipHiveMetadata: Boolean = parameters + .get(SKIP_HIVE_METADATA).map(_.toBoolean).getOrElse(DEFAULT_SKIP_HIVE_METADATA) + + // A flag to always respect the Spark schema restored from the table properties + val respectSparkSchema: Boolean = parameters + .get(RESPECT_SPARK_SCHEMA).map(_.toBoolean).getOrElse(DEFAULT_RESPECT_SPARK_SCHEMA) +} + + +object SourceOptions { + + val SKIP_HIVE_METADATA = "skipHiveMetadata" + val DEFAULT_SKIP_HIVE_METADATA = false + + val RESPECT_SPARK_SCHEMA = "respectSparkSchema" + val DEFAULT_RESPECT_SPARK_SCHEMA = false + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index bdbb8bccbc5cd..34af37ce11103 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -260,6 +260,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = { // data source table always have a provider, it's guaranteed by `DDLUtils.isDatasourceTable`. val provider = table.provider.get + val options = new SourceOptions(table.storage.properties) // To work around some hive metastore issues, e.g. not case-preserving, bad decimal type // support, no column nullability, etc., we should do some extra works before saving table @@ -325,11 +326,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val qualifiedTableName = table.identifier.quotedString val maybeSerde = HiveSerDe.sourceToSerDe(provider) - val skipHiveMetadata = table.storage.properties - .getOrElse("skipHiveMetadata", "false").toBoolean val (hiveCompatibleTable, logMessage) = maybeSerde match { - case _ if skipHiveMetadata => + case _ if options.skipHiveMetadata => val message = s"Persisting data source table $qualifiedTableName into Hive metastore in" + "Spark SQL specific format, which is NOT compatible with Hive." 
@@ -737,6 +736,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { + val options = new SourceOptions(table.storage.properties) val hiveTable = table.copy( provider = Some(DDLUtils.HIVE_PROVIDER), tracksPartitionsInCatalog = true) @@ -748,7 +748,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val partColumnNames = getPartitionColumnsFromTableProperties(table) val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) - if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema)) { + if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) || + options.respectSparkSchema) { hiveTable.copy( schema = reorderedSchema, partitionColumnNames = partColumnNames, diff --git a/sql/hive/src/test/resources/avroDecimal/decimal.avro b/sql/hive/src/test/resources/avroDecimal/decimal.avro new file mode 100755 index 0000000000000000000000000000000000000000..6da423f78661fca12d6e642bc22585099b671440 GIT binary patch literal 203 zcmeZI%3@>^ODrqO*DFrWNX<=L#8jij}OQt6}nK20*nCz0pc}r8zlDI&ia+DuKFz(gi^MnZ=p;c}iBs7CK7B$%#2Y zqm6Wwa`MwNft(PC)hR$#xrsSSwXs0R&~!TaxAOg}XkNB;_cIPZ-^#?n#KO>oE(!qS Cg-CY* literal 0 HcmV?d00001 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 072e538b9ed54..cbbe869403724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -763,6 +763,47 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: read avro file containing decimal") { + val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") + val location = new File(url.getFile) + + val tableName = "tab1" + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": [ + | "null", + | { + | "precision": 38, + | "scale": 2, + | "type": "bytes", + | "logicalType": "decimal" + | } + | ] + | } ] + |} + """.stripMargin + withTable(tableName) { + versionSpark.sql( + s""" + |CREATE TABLE $tableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true') + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$location' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + assert(versionSpark.table(tableName).collect() === + versionSpark.sql("SELECT 1.30").collect()) + } + } + // TODO: add more tests. } } From d56c262109a5d94b46fffc04954c34671b14ee4f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 22 Aug 2017 16:55:34 -0700 Subject: [PATCH 028/187] [SPARK-21681][ML] fix bug of MLOR do not work correctly when featureStd contains zero ## What changes were proposed in this pull request? 
fix bug of MLOR do not work correctly when featureStd contains zero We can reproduce the bug through such dataset (features including zero variance), will generate wrong result (all coefficients becomes 0) ``` val multinomialDatasetWithZeroVar = { val nPoints = 100 val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.16624, -0.84355, -0.048509) val xMean = Array(5.843, 3.0) val xVariance = Array(0.6856, 0.0) // including zero variance val testData = generateMultinomialLogisticInput( coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) val df = sc.parallelize(testData, 4).toDF().withColumn("weight", lit(1.0)) df.cache() df } ``` ## How was this patch tested? testcase added. Author: WeichenXu Closes #18896 from WeichenXu123/fix_mlor_stdvalue_zero_bug. --- .../optim/aggregator/LogisticAggregator.scala | 12 +-- .../LogisticRegressionSuite.scala | 78 +++++++++++++++++++ .../aggregator/LogisticAggregatorSuite.scala | 37 ++++++++- 3 files changed, 118 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala index 66a52942e668c..272d36dd94ae8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala @@ -270,11 +270,13 @@ private[ml] class LogisticAggregator( val margins = new Array[Double](numClasses) features.foreachActive { (index, value) => - val stdValue = value / localFeaturesStd(index) - var j = 0 - while (j < numClasses) { - margins(j) += localCoefficients(index * numClasses + j) * stdValue - j += 1 + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + margins(j) += localCoefficients(index * numClasses + j) * stdValue + j += 1 + } } } var i = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 0570499e74516..542977a48f0ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite @transient var smallMultinomialDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ @transient var multinomialDataset: Dataset[_] = _ + @transient var multinomialDatasetWithZeroVar: Dataset[_] = _ private val eps: Double = 1e-5 override def beforeAll(): Unit = { @@ -99,6 +100,23 @@ class LogisticRegressionSuite df.cache() df } + + multinomialDatasetWithZeroVar = { + val nPoints = 100 + val coefficients = Array( + -0.57997, 0.912083, -0.371077, + -0.16624, -0.84355, -0.048509) + + val xMean = Array(5.843, 3.0) + val xVariance = Array(0.6856, 0.0) + + val testData = generateMultinomialLogisticInput( + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) + + val df = sc.parallelize(testData, 4).toDF().withColumn("weight", lit(1.0)) + df.cache() + df + } } /** @@ -112,6 +130,11 @@ class LogisticRegressionSuite multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) => label + "," + weight + "," + features.toArray.mkString(",") }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset") + multinomialDatasetWithZeroVar.rdd.map 
{ + case Row(label: Double, features: Vector, weight: Double) => + label + "," + weight + "," + features.toArray.mkString(",") + }.repartition(1) + .saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDatasetWithZeroVar") } test("params") { @@ -1392,6 +1415,61 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with zero variance (SPARK-21681)") { + val sqlContext = multinomialDatasetWithZeroVar.sqlContext + import sqlContext.implicits._ + val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true) + .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") + + val model = mlr.fit(multinomialDatasetWithZeroVar) + + /* + Use the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = as.factor(data$V1) + w = data$V2 + features = as.matrix(data.frame(data$V3, data$V4)) + coefficients = coef(glmnet(features, label, weights=w, family="multinomial", + alpha = 0, lambda = 0)) + coefficients + $`0` + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.2658824 + data.V3 0.1881871 + data.V4 . + + $`1` + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + 0.53604701 + data.V3 -0.02412645 + data.V4 . + + $`2` + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + -0.8019294 + data.V3 -0.1640607 + data.V4 . + */ + + val coefficientsR = new DenseMatrix(3, 2, Array( + 0.1881871, 0.0, + -0.02412645, 0.0, + -0.1640607, 0.0), isTransposed = true) + val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294) + + model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + + assert(model.coefficientMatrix ~== coefficientsR relTol 0.05) + assert(model.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) + assert(model.interceptVector ~== interceptsR relTol 0.05) + assert(model.interceptVector.toArray.sum ~== 0.0 absTol eps) + } + test("multinomial logistic regression with intercept without regularization with bound") { // Bound constrained optimization with bound on one side. val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala index 2b29c67d859db..16ef4af4f94e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala @@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var instances: Array[Instance] = _ @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantFeatureFiltered: Array[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) ) + instancesConstantFeatureFiltered = Array( + Instance(0.0, 0.1, Vectors.dense(2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0)), + Instance(2.0, 0.3, Vectors.dense(0.5)) + ) } /** Get summary statistics for some data and create a new LogisticAggregator. 
*/ @@ -233,21 +239,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { val binaryInstances = instancesConstantFeature.map { instance => if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features) } + val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance => + if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features) + } val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0) + val coefArrayFiltered = Array(3.0, 0.0, -1.0) val interceptArray = Array(4.0, 2.0, -3.0) val aggConstantFeature = getNewAggregator(instancesConstantFeature, Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true) - instances.foreach(aggConstantFeature.add) + val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, + Vectors.dense(coefArrayFiltered ++ interceptArray), fitIntercept = true, isMultinomial = true) + + instancesConstantFeature.foreach(aggConstantFeature.add) + instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add) + // constant features should not affect gradient - assert(aggConstantFeature.gradient(0) === 0.0) + def validateGradient(grad: Vector, gradFiltered: Vector, numCoefficientSets: Int): Unit = { + for (i <- 0 until numCoefficientSets) { + assert(grad(i) === 0.0) + assert(grad(numCoefficientSets + i) == gradFiltered(i)) + } + } + + validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3) val binaryCoefArray = Array(1.0, 2.0) + val binaryCoefArrayFiltered = Array(2.0) val intercept = 1.0 val aggConstantFeatureBinary = getNewAggregator(binaryInstances, Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true, isMultinomial = false) - instances.foreach(aggConstantFeatureBinary.add) + val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered, + Vectors.dense(binaryCoefArrayFiltered ++ Array(intercept)), fitIntercept = true, + isMultinomial = false) + binaryInstances.foreach(aggConstantFeatureBinary.add) + binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add) + // constant features should not affect gradient - assert(aggConstantFeatureBinary.gradient(0) === 0.0) + validateGradient(aggConstantFeatureBinary.gradient, + aggConstantFeatureBinaryFiltered.gradient, 1) } } From 41bb1ddc63298c004bb6a6bb6fff9fd4f6e44792 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 22 Aug 2017 17:40:50 -0700 Subject: [PATCH 029/187] [SPARK-10931][ML][PYSPARK] PySpark Models Copy Param Values from Estimator ## What changes were proposed in this pull request? Added call to copy values of Params from Estimator to Model after fit in PySpark ML. This will copy values for any params that are also defined in the Model. Since currently most Models do not define the same params from the Estimator, also added method to create new Params from looking at the Java object if they do not exist in the Python object. This is a temporary fix that can be removed once the PySpark models properly define the params themselves. ## How was this patch tested? Refactored the `check_params` test to optionally check if the model params for Python and Java match and added this check to an existing fitted model that shares params between Estimator and Model. Author: Bryan Cutler Closes #17849 from BryanCutler/pyspark-models-own-params-SPARK-10931. 
--- python/pyspark/ml/classification.py | 2 +- python/pyspark/ml/clustering.py | 8 ++- python/pyspark/ml/tests.py | 87 +++++++++++++++++------------ python/pyspark/ml/wrapper.py | 32 ++++++++++- 4 files changed, 91 insertions(+), 38 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index bccf8e7f636f1..235cee48bc6a6 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1434,7 +1434,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs") + self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs") kwargs = self._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 88ac7e275e386..66fb00508522e 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -745,7 +745,13 @@ def toLocal(self): WARNING: This involves collecting a large :py:func:`topicsMatrix` to the driver. """ - return LocalLDAModel(self._call_java("toLocal")) + model = LocalLDAModel(self._call_java("toLocal")) + + # SPARK-10931: Temporary fix to be removed once LDAModel defines Params + model._create_params_from_java() + model._transfer_params_from_java() + + return model @since("2.0.0") def trainingLogLikelihood(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0495973d2f625..6076b3c2f26a6 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -455,6 +455,54 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + @staticmethod + def check_params(test_self, py_stage, check_params_exist=True): + """ + Checks common requirements for Params.params: + - set of params exist in Java and Python and are ordered by names + - param parent has the same UID as the object's UID + - default param value from Java matches value in Python + - optionally check if all params from Java also exist in Python + """ + py_stage_str = "%s %s" % (type(py_stage), py_stage) + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) + if check_params_exist: + param_names = [p.name for p in py_stage.params] + java_params = list(java_stage.params()) + java_param_names = [jp.name() for jp in java_params] + test_self.assertEqual( + param_names, sorted(java_param_names), + "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" + % (py_stage_str, java_param_names, param_names)) + for p in py_stage.params: + test_self.assertEqual(p.parent, py_stage.uid) + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + test_self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + continue # Random seeds between Spark and PySpark are different + java_default = _java2py(test_self.sc, + java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = py_stage.getOrDefault(p) + # equality 
test for NaN is always False + if isinstance(java_default, float) and np.isnan(java_default): + java_default = "NaN" + py_default = "NaN" if np.isnan(py_default) else "not NaN" + test_self.assertEqual( + java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + class EvaluatorTests(SparkSessionTestCase): @@ -511,6 +559,8 @@ def test_idf(self): "Model should inherit the UID from its parent estimator.") output = idf0m.transform(dataset) self.assertIsNotNone(output.head().idf) + # Test that parameters transferred to Python Model + ParamTests.check_params(self, idf0m) def test_ngram(self): dataset = self.spark.createDataFrame([ @@ -1656,40 +1706,6 @@ class DefaultValuesTests(PySparkTestCase): those in their Scala counterparts. """ - def check_params(self, py_stage): - import pyspark.ml.feature - if not hasattr(py_stage, "_to_java"): - return - java_stage = py_stage._to_java() - if java_stage is None: - return - for p in py_stage.params: - java_param = java_stage.getParam(p.name) - py_has_default = py_stage.hasDefault(p) - java_has_default = java_stage.hasDefault(java_param) - self.assertEqual(py_has_default, java_has_default, - "Default value mismatch of param %s for Params %s" - % (p.name, str(py_stage))) - if py_has_default: - if p.name == "seed": - return # Random seeds between Spark and PySpark are different - java_default =\ - _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) - py_stage._clear(p) - py_default = py_stage.getOrDefault(p) - if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue": - # SPARK-15040 - default value for Imputer param 'missingValue' is NaN, - # and NaN != NaN, so handle it specially here - import math - self.assertTrue(math.isnan(java_default) and math.isnan(py_default), - "Java default %s and python default %s are not both NaN for " - "param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - return - self.assertEqual(java_default, py_default, - "Java default %s != python default %s of param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification @@ -1703,7 +1719,8 @@ def test_java_params(self): for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and issubclass(cls, JavaParams)\ and not inspect.isabstract(cls): - self.check_params(cls()) + # NOTE: disable check_params_exist until there is parity with Scala API + ParamTests.check_params(self, cls(), check_params_exist=False) def _squared_distance(a, b): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index ee6301ef19a43..0f846fbc5b5ef 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -135,6 +135,20 @@ def _transfer_param_map_to_java(self, pyParamMap): paramMap.put([pair]) return paramMap + def _create_params_from_java(self): + """ + SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here + """ + java_params = list(self._java_obj.params()) + from pyspark.ml.param import Param + for java_param in java_params: + java_param_name = java_param.name() + if not hasattr(self, java_param_name): + param = Param(self, java_param_name, java_param.doc()) + setattr(param, "created_from_java_param", True) + setattr(self, java_param_name, param) + self._params = None # need to reset so self.params will discover 
new params + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. @@ -147,6 +161,10 @@ def _transfer_params_from_java(self): if self._java_obj.isSet(java_param): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._set(**{param.name: value}) + # SPARK-10931: Temporary fix for params that have a default in Java + if self._java_obj.hasDefault(java_param) and not self.isDefined(param): + value = _java2py(sc, self._java_obj.getDefault(java_param)).get() + self._setDefault(**{param.name: value}) def _transfer_param_map_from_java(self, javaParamMap): """ @@ -204,6 +222,11 @@ def __get_class(clazz): # Load information from java_stage to the instance. py_stage = py_type() py_stage._java_obj = java_stage + + # SPARK-10931: Temporary fix so that persisted models would own params from Estimator + if issubclass(py_type, JavaModel): + py_stage._create_params_from_java() + py_stage._resetUid(java_stage.uid()) py_stage._transfer_params_from_java() elif hasattr(py_type, "_from_java"): @@ -263,7 +286,8 @@ def _fit_java(self, dataset): def _fit(self, dataset): java_model = self._fit_java(dataset) - return self._create_model(java_model) + model = self._create_model(java_model) + return self._copyValues(model) @inherit_doc @@ -307,4 +331,10 @@ def __init__(self, java_model=None): """ super(JavaModel, self).__init__(java_model) if java_model is not None: + + # SPARK-10931: This is a temporary fix to allow models to own params + # from estimators. Eventually, these params should be in models through + # using common base classes between estimators and models. + self._create_params_from_java() + self._resetUid(java_model.uid()) From 3c0c2d09ca89c6b6247137823169db17847dfae3 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 22 Aug 2017 19:07:43 -0700 Subject: [PATCH 030/187] [SPARK-21765] Set isStreaming on leaf nodes for streaming plans. ## What changes were proposed in this pull request? All streaming logical plans will now have isStreaming set. This involved adding isStreaming as a case class arg in a few cases, since a node might be logically streaming depending on where it came from. ## How was this patch tested? Existing unit tests - no functional change is intended in this PR. Author: Jose Torres Author: Tathagata Das Closes #18973 from joseph-torres/SPARK-21765. 
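A small illustrative sketch of the new behavior (not part of the patch; it assumes the catalyst DSL imports that the touched test suites already use): leaf nodes such as `LocalRelation` now carry an explicit `isStreaming` flag, every other node derives the flag from its children via `LogicalPlan.isStreaming`, and operators such as `Deduplicate` therefore no longer need their own `streaming` argument.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

// A batch leaf: isStreaming defaults to false.
val batchLeaf = LocalRelation('a.int)
// A streaming leaf: marked explicitly through the new case class argument.
val streamingLeaf = LocalRelation(Seq('a.int), isStreaming = true)

// Non-leaf nodes derive the flag from their children
// (LogicalPlan.isStreaming is children.exists(_.isStreaming)).
assert(!Project(batchLeaf.output, batchLeaf).isStreaming)
assert(Project(streamingLeaf.output, streamingLeaf).isStreaming)
```

This is also why `ReplaceDeduplicateWithAggregate` below can match on `Deduplicate(keys, child)` and simply check `child.isStreaming` instead of carrying a separate boolean.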
--- .../spark/sql/kafka010/KafkaSource.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 10 +++--- .../plans/logical/LocalRelation.scala | 5 ++- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 11 ++++--- .../analysis/ResolveInlineTablesSuite.scala | 2 +- .../analysis/UnsupportedOperationsSuite.scala | 6 ++-- .../optimizer/ReplaceOperatorSuite.scala | 6 ++-- .../sql/catalyst/plans/LogicalPlanSuite.scala | 6 ++-- .../apache/spark/sql/DataFrameReader.scala | 4 +-- .../apache/spark/sql/DataFrameWriter.scala | 4 +-- .../scala/org/apache/spark/sql/Dataset.scala | 7 +++-- .../org/apache/spark/sql/SQLContext.scala | 7 +++-- .../org/apache/spark/sql/SparkSession.scala | 8 +++-- .../spark/sql/execution/ExistingRDD.scala | 8 +++-- .../execution/OptimizeMetadataOnlyQuery.scala | 6 ++-- .../spark/sql/execution/SparkStrategies.scala | 8 +++-- .../execution/datasources/DataSource.scala | 2 +- .../datasources/DataSourceStrategy.scala | 15 ++++----- .../datasources/FileSourceStrategy.scala | 2 +- .../datasources/LogicalRelation.scala | 12 ++++--- .../PruneFileSourcePartitions.scala | 1 + .../sql/execution/datasources/rules.scala | 10 +++--- .../streaming/FileStreamSource.scala | 2 +- .../streaming/RateSourceProvider.scala | 5 +-- .../execution/streaming/StreamExecution.scala | 3 ++ .../sql/execution/streaming/memory.scala | 31 +++++++++++++++---- .../OptimizeMetadataOnlyQuerySuite.scala | 4 +-- .../sql/execution/SparkPlannerSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 3 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../spark/sql/sources/FilteredScanSuite.scala | 2 +- .../spark/sql/sources/PathOptionSuite.scala | 2 +- .../sql/streaming/FileStreamSinkSuite.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 5 ++- .../spark/sql/streaming/StreamSuite.scala | 12 ++++++- .../streaming/StreamingAggregationSuite.scala | 29 +++++++++++++++-- .../sql/streaming/StreamingQuerySuite.scala | 5 ++- .../test/DataStreamReaderWriterSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../spark/sql/hive/orc/OrcFilterSuite.scala | 4 +-- .../apache/spark/sql/hive/parquetSuites.scala | 8 ++--- 46 files changed, 180 insertions(+), 97 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 7ac183776e20d..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -310,7 +310,7 @@ private[kafka010] class KafkaSource( currentPartitionOffsets = Some(untilPartitionOffsets) } - sqlContext.internalCreateDataFrame(rdd, schema) + sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } /** Stop this source and free any resources it has allocated. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e2d7164d93ac1..75d83bc6e86f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1175,14 +1175,14 @@ object DecimalAggregates extends Rule[LogicalPlan] { */ object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(projectList, LocalRelation(output, data)) + case Project(projectList, LocalRelation(output, data, isStreaming)) if !projectList.exists(hasUnevaluableExpr) => val projection = new InterpretedProjection(projectList, output) projection.initialize(0) - LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + LocalRelation(projectList.map(_.toAttribute), data.map(projection), isStreaming) - case Limit(IntegerLiteral(limit), LocalRelation(output, data)) => - LocalRelation(output, data.take(limit)) + case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) => + LocalRelation(output, data.take(limit), isStreaming) } private def hasUnevaluableExpr(expr: Expression): Boolean = { @@ -1207,7 +1207,7 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { */ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Deduplicate(keys, child, streaming) if !streaming => + case Deduplicate(keys, child) if !child.isStreaming => val keyExprIds = keys.map(_.exprId) val aggCols = child.output.map { attr => if (keyExprIds.contains(attr.exprId)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1c986fbde7ada..7a21183664c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,7 +43,10 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) +case class LocalRelation(output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9b440cd99f994..d893b392e56b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -47,7 +47,7 @@ abstract class LogicalPlan */ def analyzed: Boolean = _analyzed - /** Returns true if this subtree contains any streaming data sources. */ + /** Returns true if this subtree has data from a streaming data source. 
*/ def isStreaming: Boolean = children.exists(_.isStreaming == true) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 303014e0b8d31..4b3054dbfe2f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -429,9 +429,10 @@ case class Sort( /** Factory for constructing new `Range` nodes. */ object Range { - def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { + def apply(start: Long, end: Long, step: Long, + numSlices: Option[Int], isStreaming: Boolean = false): Range = { val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes - new Range(start, end, step, numSlices, output) + new Range(start, end, step, numSlices, output, isStreaming) } def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { Range(start, end, step, Some(numSlices)) @@ -443,7 +444,8 @@ case class Range( end: Long, step: Long, numSlices: Option[Int], - output: Seq[Attribute]) + output: Seq[Attribute], + override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { require(step != 0, s"step ($step) cannot be 0") @@ -784,8 +786,7 @@ case class OneRowRelation() extends LeafNode { /** A logical plan for `dropDuplicates`. */ case class Deduplicate( keys: Seq[Attribute], - child: LogicalPlan, - streaming: Boolean) extends UnaryNode { + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index d0fe815052256..9e99c8e11cdfe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -93,7 +93,7 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) val withTimeZone = ResolveTimeZone(conf).apply(table) - val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) + val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] assert(output.map(_.dataType) == Seq(TimestampType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index f68d930f60523..4de75866e04a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -368,18 +368,18 @@ class UnsupportedOperationsSuite extends SparkFunSuite { Aggregate( Seq(attributeWithWatermark), aggExprs("c"), - Deduplicate(Seq(att), streamRelation, streaming = true)), + Deduplicate(Seq(att), streamRelation)), outputMode = Append) 
assertNotSupportedInStreamingPlan( "Deduplicate - Deduplicate on streaming relation after aggregation", - Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true), + Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation)), outputMode = Complete, expectedMsgs = Seq("dropDuplicates")) assertSupportedInStreamingPlan( "Deduplicate - Deduplicate on batch relation inside a streaming query", - Deduplicate(Seq(att), batchRelation, streaming = false), + Deduplicate(Seq(att), batchRelation), outputMode = Append ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index e68423f85c92e..85988d2fb948c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -79,7 +79,7 @@ class ReplaceOperatorSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int) val attrA = input.output(0) val attrB = input.output(1) - val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a") + val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a") val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -95,9 +95,9 @@ class ReplaceOperatorSuite extends PlanTest { } test("don't replace streaming Deduplicate") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) - val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a") + val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a") val optimized = Optimize.execute(query.analyze) comparePlans(optimized, query) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index cc86f1f6e2f48..cdf912df7c76a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -73,10 +73,8 @@ class LogicalPlanSuite extends SparkFunSuite { test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - val incrementalRelation = new LocalRelation( - Seq(AttributeReference("a", IntegerType, nullable = true)())) { - override def isStreaming(): Boolean = true - } + val incrementalRelation = LocalRelation( + Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true) case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 10b28ce812afc..41cb019499ae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -410,7 +410,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { Dataset.ofRows( sparkSession, - LogicalRDD(schema.toAttributes, parsed)(sparkSession)) + LogicalRDD(schema.toAttributes, parsed, isStreaming = jsonDataset.isStreaming)(sparkSession)) } /** @@ -473,7 +473,7 @@ 
class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { Dataset.ofRows( sparkSession, - LogicalRDD(schema.toAttributes, parsed)(sparkSession)) + LogicalRDD(schema.toAttributes, parsed, isStreaming = csvDataset.isStreaming)(sparkSession)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 877051a60e910..cca93525d6792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -371,14 +371,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { case (true, SaveMode.Overwrite) => // Get all input data source or hive relations of the query. val srcRelations = df.logicalPlan.collect { - case LogicalRelation(src: BaseRelation, _, _) => src + case LogicalRelation(src: BaseRelation, _, _, _) => src case relation: HiveTableRelation => relation.tableMeta.identifier } val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed EliminateSubqueryAliases(tableRelation) match { // check if the table is a data source table (the relation is a BaseRelation). - case LogicalRelation(dest: BaseRelation, _, _) if srcRelations.contains(dest) => + case LogicalRelation(dest: BaseRelation, _, _, _) if srcRelations.contains(dest) => throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") // check hive table relation when overwrite mode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 615686ccbe2b3..c6707396af1a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -569,7 +569,8 @@ class Dataset[T] private[sql]( logicalPlan.output, internalRdd, outputPartitioning, - physicalPlan.outputOrdering + physicalPlan.outputOrdering, + isStreaming )(sparkSession)).as[T] } @@ -2233,7 +2234,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan, isStreaming) + Deduplicate(groupCols, logicalPlan) } /** @@ -2993,7 +2994,7 @@ class Dataset[T] private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = queryExecution.optimizedPlan.collect { - case LogicalRelation(fsBasedRelation: FileRelation, _, _) => + case LogicalRelation(fsBasedRelation: FileRelation, _, _, _) => fsBasedRelation.inputFiles case fr: FileRelation => fr.inputFiles diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7fde6e9469e5e..af6018472cb03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -420,8 +420,11 @@ class SQLContext private[sql](val sparkSession: SparkSession) * converted to Catalyst rows. 
*/ private[sql] - def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = { - sparkSession.internalCreateDataFrame(catalystRows, schema) + def internalCreateDataFrame( + catalystRows: RDD[InternalRow], + schema: StructType, + isStreaming: Boolean = false) = { + sparkSession.internalCreateDataFrame(catalystRows, schema, isStreaming) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6dfe8a66baa9b..863c316bbac65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -564,10 +564,14 @@ class SparkSession private( */ private[sql] def internalCreateDataFrame( catalystRows: RDD[InternalRow], - schema: StructType): DataFrame = { + schema: StructType, + isStreaming: Boolean = false): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + val logicalPlan = LogicalRDD( + schema.toAttributes, + catalystRows, + isStreaming = isStreaming)(self) Dataset.ofRows(self, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index dcb918eeb9d10..f3555508185fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,8 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession) + outputOrdering: Seq[SortOrder] = Nil, + override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil @@ -150,11 +151,12 @@ case class LogicalRDD( output.map(rewrite), rdd, rewrittenPartitioning, - rewrittenOrdering + rewrittenOrdering, + isStreaming )(session).asInstanceOf[this.type] } - override protected def stringArgs: Iterator[Any] = Iterator(output) + override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming) override def computeStats(): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 301c4f02647d5..18f6f697bc857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -94,10 +94,10 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic child transform { case plan if plan eq relation => relation match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) => + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) val partitionData = fsRelation.location.listFiles(Nil, Nil) - LocalRelation(partAttrs, partitionData.map(_.values)) + LocalRelation(partAttrs, partitionData.map(_.values), 
isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -130,7 +130,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic object PartitionedRelation { def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) if fsRelation.partitionSchema.nonEmpty => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) Some((AttributeSet(partAttrs), l)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c115cb6e80e91..6b16408e27840 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -221,12 +221,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Used to plan aggregation queries that are computed incrementally as part of a + * Used to plan streaming aggregation queries that are computed incrementally as part of a * [[StreamingQuery]]. Currently this rule is injected into the planner * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] */ object StatefulAggregationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case _ if !plan.isStreaming => Nil + case EventTimeWatermark(columnName, delay, child) => EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil @@ -248,7 +250,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object StreamingDeduplicationStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case Deduplicate(keys, child, true) => + case Deduplicate(keys, child) if child.isStreaming => StreamingDeduplicateExec(keys, planLater(child)) :: Nil case _ => Nil @@ -410,7 +412,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil - case logical.LocalRelation(output, data) => + case logical.LocalRelation(output, data, _) => LocalTableScanExec(output, data) :: Nil case logical.LocalLimit(IntegerLiteral(limit), child) => execution.LocalLimitExec(limit, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 567ff49773f9b..b9502a95a7c08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -455,7 +455,7 @@ case class DataSource( val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { - case LogicalRelation(t: HadoopFsRelation, _, _) => t.location + case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location }.head } // For partitioned relation r, r.schema's column ordering can be different from the column diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 237017742770a..0deac1984bd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -136,12 +136,12 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) - case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _), + case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => InsertIntoDataSourceCommand(l, query, overwrite) case i @ InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) => + l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, query, overwrite, _) => // If the InsertIntoTable command is for a partitioned HadoopFsRelation and // the user has specified static partitions, we add a Project operator on top of the query // to include those constant column values in the query result. @@ -177,7 +177,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast val outputPath = t.location.rootPaths.head val inputPaths = actualQuery.collect { - case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths + case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths }.flatten val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append @@ -268,7 +268,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with import DataSourceStrategy._ def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) => pruneFilterProjectRaw( l, projects, @@ -276,21 +276,22 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with (requestedColumns, allPredicates, _) => toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) => + case PhysicalOperation(projects, filters, + l @ LogicalRelation(t: PrunedFilteredScan, _, _, _)) => pruneFilterProject( l, projects, filters, (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) => pruneFilterProject( l, projects, filters, (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil - case l @ LogicalRelation(baseRelation: TableScan, _, _) => + case l @ LogicalRelation(baseRelation: TableScan, _, _, _) => RowDataSourceScanExec( l.output, l.output.indices, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 17f7e0e601c0c..16b22717b8d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -52,7 +52,7 @@ import org.apache.spark.sql.execution.SparkPlan object FileSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, - l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table)) => + l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: // - partition keys only - used to prune directories to read diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 699f1bad9c4ed..17a61074d3b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -30,12 +30,14 @@ import org.apache.spark.util.Utils case class LogicalRelation( relation: BaseRelation, output: Seq[AttributeReference], - catalogTable: Option[CatalogTable]) + catalogTable: Option[CatalogTable], + override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output + case l @ LogicalRelation(otherRelation, _, _, isStreaming) => + relation == otherRelation && output == l.output && isStreaming == l.isStreaming case _ => false } @@ -76,9 +78,9 @@ case class LogicalRelation( } object LogicalRelation { - def apply(relation: BaseRelation): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, None) + def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming) def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, Some(table)) + LogicalRelation(relation, relation.schema.toAttributes, Some(table), false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index f5df1848a38c4..3b830accb83f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -36,6 +36,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _, _), _, + _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => // The attribute name of predicate could be different than the one in schema in case of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 84acca242aa41..7a2c85e8e01f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -385,10 +385,10 @@ case class PreprocessTableInsertion(conf: SQLConf) extends 
Rule[LogicalPlan] wit case relation: HiveTableRelation => val metadata = relation.tableMeta preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames) - case LogicalRelation(h: HadoopFsRelation, _, catalogTable) => + case LogicalRelation(h: HadoopFsRelation, _, catalogTable, _) => val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") preprocess(i, tblName, h.partitionSchema.map(_.name)) - case LogicalRelation(_: InsertableRelation, _, catalogTable) => + case LogicalRelation(_: InsertableRelation, _, catalogTable, _) => val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") preprocess(i, tblName, Nil) case _ => i @@ -428,7 +428,7 @@ object PreReadCheck extends (LogicalPlan => Unit) { private def checkNumInputFileBlockSources(e: Expression, operator: LogicalPlan): Int = { operator match { case _: HiveTableRelation => 1 - case _ @ LogicalRelation(_: HadoopFsRelation, _, _) => 1 + case _ @ LogicalRelation(_: HadoopFsRelation, _, _, _) => 1 case _: LeafNode => 0 // UNION ALL has multiple children, but these children do not concurrently use InputFileBlock. case u: Union => @@ -454,10 +454,10 @@ object PreWriteCheck extends (LogicalPlan => Unit) { def apply(plan: LogicalPlan): Unit = { plan.foreach { - case InsertIntoTable(l @ LogicalRelation(relation, _, _), partition, query, _, _) => + case InsertIntoTable(l @ LogicalRelation(relation, _, _, _), partition, query, _, _) => // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src, _, _) => src + case LogicalRelation(src, _, _, _) => src } if (srcRelations.contains(relation)) { failAnalysis("Cannot insert into table that is also being read from.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 4b1b2520390ba..f17417343e289 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -171,7 +171,7 @@ class FileStreamSource( className = fileFormatClassName, options = optionsWithPartitionBasePath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( - checkFilesExist = false))) + checkFilesExist = false), isStreaming = true)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index e76d4dc6125df..077a4778e34a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -200,7 +200,8 @@ class RateStreamSource( s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema) + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) } val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -211,7 +212,7 @@ class RateStreamSource( val relative = math.round((v - rangeStart) * relativeMsPerValue) InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) } - sqlContext.internalCreateDataFrame(rdd, schema) + sqlContext.internalCreateDataFrame(rdd, schema, 
isStreaming = true) } override def stop(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 9bc114f138562..432b2d4925ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -609,6 +609,9 @@ class StreamExecution( if committedOffsets.get(source).map(_ != available).getOrElse(true) => val current = committedOffsets.get(source) val batch = source.getBatch(current, available) + assert(batch.isStreaming, + s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 587ae2bfb63fb..c9784c093b408 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -27,13 +29,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils + object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -44,7 +47,7 @@ object MemoryStream { /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] - * is primarily intended for use in unit tests as it can only replay data when the object is still + * is intended for use in unit tests as it can only replay data when the object is still * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @@ -85,8 +88,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } def addData(data: TraversableOnce[A]): Offset = { - import sqlContext.implicits._ - val ds = data.toVector.toDS() + val encoded = data.toVector.map(d => encoder.toRow(d).copy()) + val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) + val ds = Dataset[A](sqlContext.sparkSession, plan) logDebug(s"Adding ds: $ds") this.synchronized { currentOffset = currentOffset + 1 @@ -118,8 +122,8 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) batches.slice(sliceStart, sliceEnd) } - logDebug( - s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") + logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) + newBlocks .map(_.toDF()) .reduceOption(_ union _) @@ -128,6 +132,21 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } + private def generateDebugString( + blocks: TraversableOnce[Dataset[A]], + startOrdinal: Int, + endOrdinal: Int): String = { + val originalUnsupportedCheck = + sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") + try { + sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") + s"MemoryBatch [$startOrdinal, $endOrdinal]: " + + s"${blocks.flatMap(_.collect()).mkString(", ")}" + } finally { + sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) + } + } + override def commit(end: Offset): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 58c310596ca6d..223c3d7729a50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -42,14 +42,14 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { private def assertMetadataOnlyQuery(df: DataFrame): Unit = { val localRelations = df.queryExecution.optimizedPlan.collect { - case l @ LocalRelation(_, _) => l + case l @ LocalRelation(_, _, _) => l } assert(localRelations.size == 1) } private def assertNotMetadataOnlyQuery(df: DataFrame): Unit = { val localRelations = df.queryExecution.optimizedPlan.collect { - case l @ LocalRelation(_, _) => l + case l @ LocalRelation(_, _, _) => l } assert(localRelations.size == 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index aecfd3062147c..5828f9783da42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -40,7 +40,7 @@ class SparkPlannerSuite extends SharedSQLContext { case Union(children) => planned += 1 UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil - case LocalRelation(output, data) => + case LocalRelation(output, data, _) => planned += 1 LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil case NeverPlanned => diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index d77f0c298ffe3..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -556,7 +556,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi if (buckets > 0) { val bucketed = df.queryExecution.analyzed transform { - case l @ LogicalRelation(r: HadoopFsRelation, _, _) => + case l @ LogicalRelation(r: HadoopFsRelation, _, _, _) => l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))(r.sparkSession)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c43c1ec8b9a6b..28e8521b35fa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -63,7 +63,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(relation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, + LogicalRelation(relation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(relation) filters }.flatten.reduceLeftOption(_ && _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2f5fd8438f682..837a0872d7b71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -651,7 +651,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation( - HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _) => + HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _, _) => assert(location.partitionSpec() === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a matching HadoopFsRelation, but got:\n$queryExecution") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8dc11d80c3063..f951b46e4dd7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -247,7 +247,7 @@ class JDBCSuite extends SparkFunSuite // Check whether the tables are fetched in the expected degree of parallelism def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = { val jdbcRelations = df.queryExecution.analyzed.collect { - case LogicalRelation(r: JDBCRelation, _, _) => r + case LogicalRelation(r: 
JDBCRelation, _, _, _) => r } assert(jdbcRelations.length == 1) assert(jdbcRelations.head.parts.length == expectedNumPartitions, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index fe9469b49e385..c45b507d2b489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -327,7 +327,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic val table = spark.table("oneToTenFiltered") val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _, _) => r + case LogicalRelation(r, _, _, _) => r }.get assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 3fd7a5be1da37..85da3f0e38468 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -135,7 +135,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { private def getPathOption(tableName: String): Option[String] = { spark.table(tableName).queryExecution.analyzed.collect { - case LogicalRelation(r: TestOptionsRelation, _, _) => r.pathOption + case LogicalRelation(r: TestOptionsRelation, _, _, _) => r.pathOption }.head } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index a5cf40c3581c6..08db06b94904b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -127,7 +127,7 @@ class FileStreamSinkSuite extends StreamTest { // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has // been inferred val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { - case LogicalRelation(baseRelation: HadoopFsRelation, _, _) => baseRelation + case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation } assert(hadoopdFsRelations.size === 1) assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index e2ec690d90e52..b6baaed1927e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1105,7 +1105,10 @@ class FileStreamSourceSuite extends FileStreamSourceTest { def verify(startId: Option[Int], endId: Int, expected: String*): Unit = { val start = startId.map(new FileStreamSourceOffset(_)) val end = FileStreamSourceOffset(endId) - assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) + } } verify(startId = None, endId = 2, "keep1", "keep2", "keep3") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6f7b9d35a6bb3..012cccfdd9166 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ @@ -728,7 +729,16 @@ class FakeDefaultSource extends FakeSource { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + val ds = new Dataset[java.lang.Long]( + spark.sparkSession, + Range( + startOffset, + end.asInstanceOf[LongOffset].offset + 1, + 1, + Some(spark.sparkSession.sparkContext.defaultParallelism), + isStreaming = true), + Encoders.LONG) + ds.toDF("a") } override def stop() {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index b6e82b621c8cb..e0979ce296c3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming import java.util.{Locale, TimeZone} +import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ @@ -31,12 +33,14 @@ import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.StructType object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingAggregationSuite extends StateStoreMetricsTest + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -356,4 +360,25 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte CheckLastBatch((90L, 1), (100L, 1), (105L, 1)) ) } + + test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + val streamInput = MemoryStream[Int] + val batchDF = Seq(1, 2, 3, 4, 5) + .toDF("value") + .withColumn("parity", 'value % 2) + .groupBy('parity) + .agg(count("*") as 'joinValue) + val joinDF = streamInput + .toDF() + .join(batchDF, 'value === 'parity) + + // make sure we're planning an aggregate in the first place + assert(batchDF.queryExecution.optimizedPlan match { case _: Aggregate => true }) + + testStream(joinDF, Append)( + AddData(streamInput, 0, 1, 2, 3), + CheckLastBatch((0, 0, 2), (1, 1, 3)), + 
AddData(streamInput, 0, 1, 2, 3), + CheckLastBatch((0, 0, 2), (1, 1, 3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 27ea6902fa1fd..969f594edf615 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -647,7 +647,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val source = new Source() { override def schema: StructType = triggerDF.schema override def getOffset: Option[Offset] = Some(LongOffset(0)) - override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + sqlContext.internalCreateDataFrame( + triggerDF.queryExecution.toRdd, triggerDF.schema, isStreaming = true) + } override def stop(): Unit = {} } StreamingExecutionRelation(source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index e8a6202b8adce..aa163d2211c38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -88,7 +88,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { import spark.implicits._ - Seq[Int]().toDS().toDF() + spark.internalCreateDataFrame(spark.sparkContext.emptyRDD, schema, isStreaming = true) } override def stop() {} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8bab059ed5e84..f0f2c493498b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -73,7 +73,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log catalogProxy.getCachedTable(tableIdentifier) match { case null => None // Cache miss - case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => + case logical @ LogicalRelation(relation: HadoopFsRelation, _, _, _) => val cachedRelationFileFormatClass = relation.fileFormat.getClass expectedFileFormat match { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e01198dd53178..83cee5d1b8a42 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -583,7 +583,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: HadoopFsRelation, _, _) => // OK + case LogicalRelation(p: HadoopFsRelation, _, _, _) => // OK case _ => fail(s"test_parquet_ctas should have be converted to ${classOf[HadoopFsRelation]}") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 02cfa02a37886..d2a6ef7b2b377 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -411,7 +411,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val catalogTable = sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) relation match { - case LogicalRelation(r: HadoopFsRelation, _, _) => + case LogicalRelation(r: HadoopFsRelation, _, _, _) => if (!isDataSourceTable) { fail( s"${classOf[HiveTableRelation].getCanonicalName} is expected, but found " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index 222c24927a763..de6f0d67f1734 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -45,7 +45,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters }.flatten.reduceLeftOption(_ && _) @@ -89,7 +89,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { var maybeRelation: Option[HadoopFsRelation] = None val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => maybeRelation = Some(orcRelation) filters }.flatten.reduceLeftOption(_ && _) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 303884da19f09..740e0837350cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -285,7 +285,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK case _ => fail( "test_parquet_ctas should be converted to " + s"${classOf[HadoopFsRelation ].getCanonicalName }") @@ -370,7 +370,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: HadoopFsRelation, _, _) => r + case r @ LogicalRelation(_: HadoopFsRelation, _, _, _) => r }.size } } @@ -379,7 +379,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { def collectHadoopFsRelation(df: DataFrame): HadoopFsRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: HadoopFsRelation, _, _) => r + case LogicalRelation(r: HadoopFsRelation, _, _, _) => r }.getOrElse { fail(s"Expecting a HadoopFsRelation 2, but got:\n$plan") } @@ -459,7 +459,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. 
getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case LogicalRelation(_: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + From 34296190558435fce73184fb7fb1e3d2ced7c3f6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 23 Aug 2017 11:06:53 +0800 Subject: [PATCH 031/187] [ML][MINOR] Make sharedParams update. ## What changes were proposed in this pull request? ```sharedParams.scala``` was generated by ```SharedParamsCodeGen```, but it's not updated in master. Maybe someone manual update ```sharedParams.scala```, this PR fix this issue. ## How was this patch tested? Offline check. Author: Yanbo Liang Closes #19011 from yanboliang/sharedParams. --- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 545e45e84e9ea..6061d9ca0a084 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -154,7 +154,7 @@ private[ml] trait HasVarianceCol extends Params { } /** - * Trait for shared param threshold (default: 0.5). + * Trait for shared param threshold. */ private[ml] trait HasThreshold extends Params { From d58a3507ed2d48eabb857c92aecead19a52f4952 Mon Sep 17 00:00:00 2001 From: Jane Wang Date: Wed, 23 Aug 2017 11:31:54 +0800 Subject: [PATCH 032/187] [SPARK-19326] Speculated task attempts do not get launched in few scenarios ## What changes were proposed in this pull request? Add a new listener event when a speculative task is created and notify it to ExecutorAllocationManager for requesting more executor. ## How was this patch tested? - Added Unittests. - For the test snippet in the jira: val n = 100 val someRDD = sc.parallelize(1 to n, n) someRDD.mapPartitionsWithIndex( (index: Int, it: Iterator[Int]) => { if (index == 1) { Thread.sleep(Long.MaxValue) // fake long running task(s) } it.toList.map(x => index + ", " + x).iterator }).collect With this code change, spark indicates 101 jobs are running (99 succeeded, 2 running and 1 is speculative job) Author: Jane Wang Closes #18492 from janewangfb/speculated_task_not_launched. 
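For downstream consumers, the new event is observable through the standard listener API. The sketch below is illustrative only (the listener class name and counting logic are hypothetical); it relies solely on the `SparkListenerSpeculativeTaskSubmitted` event and `onSpeculativeTaskSubmitted` callback introduced in the diff that follows, registered via the existing `SparkContext.addSparkListener`:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerSpeculativeTaskSubmitted}

// Hypothetical application-side listener: counts speculative attempts per stage using only
// the event type added by this patch.
class SpeculationCountListener extends SparkListener {
  private val countsByStage =
    scala.collection.mutable.HashMap.empty[Int, Int].withDefaultValue(0)

  override def onSpeculativeTaskSubmitted(
      speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = synchronized {
    countsByStage(speculativeTask.stageId) += 1
  }

  def speculativeTaskCount(stageId: Int): Int = synchronized(countsByStage(stageId))
}

// Registration uses the existing API: sc.addSparkListener(new SpeculationCountListener)
```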
--- .../apache/spark/SparkFirehoseListener.java | 5 ++ .../spark/ExecutorAllocationManager.scala | 61 ++++++++++++++++--- .../apache/spark/scheduler/DAGScheduler.scala | 14 +++++ .../spark/scheduler/DAGSchedulerEvent.scala | 4 ++ .../spark/scheduler/SparkListener.scala | 11 ++++ .../spark/scheduler/SparkListenerBus.scala | 2 + .../spark/scheduler/TaskSetManager.scala | 1 + .../ExecutorAllocationManagerSuite.scala | 48 ++++++++++++++- .../spark/scheduler/TaskSetManagerSuite.scala | 9 +++ 9 files changed, 144 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 140c52fd12f94..3583856d88998 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -139,6 +139,11 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onSpeculativeTaskSubmitted(SparkListenerSpeculativeTaskSubmitted speculativeTask) { + onEvent(speculativeTask); + } + @Override public void onOtherEvent(SparkListenerEvent event) { onEvent(event); diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 337631a6f9a34..33503260bbe02 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -373,8 +373,14 @@ private[spark] class ExecutorAllocationManager( // If our target has not changed, do not send a message // to the cluster manager and reset our exponential growth if (delta == 0) { - numExecutorsToAdd = 1 - return 0 + // Check if there is any speculative jobs pending + if (listener.pendingTasks == 0 && listener.pendingSpeculativeTasks > 0) { + numExecutorsTarget = + math.max(math.min(maxNumExecutorsNeeded + 1, maxNumExecutors), minNumExecutors) + } else { + numExecutorsToAdd = 1 + return 0 + } } val addRequestAcknowledged = try { @@ -588,17 +594,22 @@ private[spark] class ExecutorAllocationManager( * A listener that notifies the given allocation manager of when to add and remove executors. * * This class is intentionally conservative in its assumptions about the relative ordering - * and consistency of events returned by the listener. For simplicity, it does not account - * for speculated tasks. + * and consistency of events returned by the listener. */ private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] - // Number of tasks currently running on the cluster. Should be 0 when no stages are active. + // Number of tasks currently running on the cluster including speculative tasks. + // Should be 0 when no stages are active. 
private var numRunningTasks: Int = _ + // Number of speculative tasks to be scheduled in each stage + private val stageIdToNumSpeculativeTasks = new mutable.HashMap[Int, Int] + // The speculative tasks started in each stage + private val stageIdToSpeculativeTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] + // stageId to tuple (the number of task with locality preferences, a map where each pair is a // node and the number of tasks that would like to be scheduled on that node) map, // maintain the executor placement hints for each stage Id used by resource framework to better @@ -637,7 +648,9 @@ private[spark] class ExecutorAllocationManager( val stageId = stageCompleted.stageInfo.stageId allocationManager.synchronized { stageIdToNumTasks -= stageId + stageIdToNumSpeculativeTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToSpeculativeTaskIndices -= stageId stageIdToExecutorPlacementHints -= stageId // Update the executor placement hints @@ -645,7 +658,7 @@ private[spark] class ExecutorAllocationManager( // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason - if (stageIdToNumTasks.isEmpty) { + if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() if (numRunningTasks != 0) { logWarning("No stages are running, but numRunningTasks != 0") @@ -671,7 +684,12 @@ private[spark] class ExecutorAllocationManager( } // If this is the last pending task, mark the scheduler queue as empty - stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + if (taskStart.taskInfo.speculative) { + stageIdToSpeculativeTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += + taskIndex + } else { + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + } if (totalPendingTasks() == 0) { allocationManager.onSchedulerQueueEmpty() } @@ -705,7 +723,11 @@ private[spark] class ExecutorAllocationManager( if (totalPendingTasks() == 0) { allocationManager.onSchedulerBacklogged() } - stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + if (taskEnd.taskInfo.speculative) { + stageIdToSpeculativeTaskIndices.get(stageId).foreach {_.remove(taskIndex)} + } else { + stageIdToTaskIndices.get(stageId).foreach {_.remove(taskIndex)} + } } } } @@ -726,18 +748,39 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorRemoved(executorRemoved.executorId) } + override def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted) + : Unit = { + val stageId = speculativeTask.stageId + + allocationManager.synchronized { + stageIdToNumSpeculativeTasks(stageId) = + stageIdToNumSpeculativeTasks.getOrElse(stageId, 0) + 1 + allocationManager.onSchedulerBacklogged() + } + } + /** * An estimate of the total number of pending tasks remaining for currently running stages. Does * not account for tasks which may have failed and been resubmitted. * * Note: This is not thread-safe without the caller owning the `allocationManager` lock. 
*/ - def totalPendingTasks(): Int = { + def pendingTasks(): Int = { stageIdToNumTasks.map { case (stageId, numTasks) => numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) }.sum } + def pendingSpeculativeTasks(): Int = { + stageIdToNumSpeculativeTasks.map { case (stageId, numTasks) => + numTasks - stageIdToSpeculativeTaskIndices.get(stageId).map(_.size).getOrElse(0) + }.sum + } + + def totalPendingTasks(): Int = { + pendingTasks + pendingSpeculativeTasks + } + /** * The number of tasks currently running across all stages. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 21bf9d013ebef..562dd1da4fe14 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -281,6 +281,13 @@ class DAGScheduler( eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } + /** + * Called by the TaskSetManager when it decides a speculative task is needed. + */ + def speculativeTaskSubmitted(task: Task[_]): Unit = { + eventProcessLoop.post(SpeculativeTaskSubmitted(task)) + } + private[scheduler] def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times @@ -812,6 +819,10 @@ class DAGScheduler( listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } + private[scheduler] def handleSpeculativeTaskSubmitted(task: Task[_]): Unit = { + listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId)) + } + private[scheduler] def handleTaskSetFailed( taskSet: TaskSet, reason: String, @@ -1778,6 +1789,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) + case SpeculativeTaskSubmitted(task) => + dagScheduler.handleSpeculativeTaskSubmitted(task) + case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 3f8d5639a2b90..54ab8f8b3e1d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -94,3 +94,7 @@ case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Thr extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + +private[scheduler] +case class SpeculativeTaskSubmitted(task: Task[_]) extends DAGSchedulerEvent + diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 59f89a82a1da8..b76e560669d59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -52,6 +52,9 @@ case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: T @DeveloperApi case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerSpeculativeTaskSubmitted(stageId: Int) extends SparkListenerEvent + @DeveloperApi case class SparkListenerTaskEnd( stageId: Int, @@ -290,6 +293,11 @@ private[spark] trait SparkListenerInterface { */ def 
onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit + /** + * Called when a speculative task is submitted + */ + def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit + /** * Called when other events like SQL-specific events are posted. */ @@ -354,5 +362,8 @@ abstract class SparkListener extends SparkListenerInterface { override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { } + override def onSpeculativeTaskSubmitted( + speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 3b0d3b1b150fe..056c0cbded435 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -71,6 +71,8 @@ private[spark] trait SparkListenerBus listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) + case speculativeTaskSubmitted: SparkListenerSpeculativeTaskSubmitted => + listener.onSpeculativeTaskSubmitted(speculativeTaskSubmitted) case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c2f817858473c..3804ea863b4f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -966,6 +966,7 @@ private[spark] class TaskSetManager( "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms" .format(index, taskSet.id, info.host, threshold)) speculatableTasks += index + sched.dagScheduler.speculativeTaskSubmitted(tasks(index)) foundTasks = true } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index b9ce71a0c5254..7da4bae0ab7eb 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -188,6 +188,40 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 10) } + test("add executors when speculative tasks added") { + sc = createSparkContext(0, 10, 0) + val manager = sc.executorAllocationManager.get + + // Verify that we're capped at number of tasks including the speculative ones in the stage + sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) + assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) + sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 5) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task doesn't affect the target + 
sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(numExecutorsTarget(manager) === 5) + assert(addExecutors(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a speculative task doesn't affect the target + sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) + assert(numExecutorsTarget(manager) === 5) + assert(addExecutors(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + } + test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get @@ -1031,10 +1065,15 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { taskLocalityPreferences = taskLocalityPreferences) } - private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { - new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative = false) + private def createTaskInfo( + taskId: Int, + taskIndex: Int, + executorId: String, + speculative: Boolean = false): TaskInfo = { + new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative) } + /* ------------------------------------------------------- * | Helper methods for accessing private methods and fields | * ------------------------------------------------------- */ @@ -1061,6 +1100,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) + private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit]('onSpeculativeTaskSubmitted) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -1136,6 +1176,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _onExecutorBusy(id) } + private def onSpeculativeTaskSubmitted(manager: ExecutorAllocationManager, id: String) : Unit = { + manager invokePrivate _onSpeculativeTaskSubmitted(id) + } + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { manager invokePrivate _localityAwareTasks() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6f1663b210969..ae43f4cadc037 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -60,6 +60,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } + + override def speculativeTaskSubmitted(task: Task[_]): Unit = { + taskScheduler.speculativeTasks += task.partitionId + } } // Get the rack for a given host @@ -92,6 +96,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val endedTasks = new mutable.HashMap[Long, TaskEndReason] val finishedManagers = new ArrayBuffer[TaskSetManager] val taskSetsFailed = new ArrayBuffer[String] + val speculativeTasks = new ArrayBuffer[Int] val executors = new mutable.HashMap[String, String] for ((execId, host) <- liveExecutors) { @@ -139,6 +144,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, 
String)* /* ex } } + override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value) } @@ -929,6 +935,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // > 0ms, so advance the clock by 1ms here. clock.advance(1) assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3)) + // Offer resource to start the speculative attempt for the running task val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption5.isDefined) @@ -1016,6 +1024,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // > 0ms, so advance the clock by 1ms here. clock.advance(1) assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3, 4)) // Offer resource to start the speculative attempt for the running task val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption5.isDefined) From d6b30edd4974b593cc8085f680ccb524c7722c85 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 22 Aug 2017 21:16:34 -0700 Subject: [PATCH 033/187] [SPARK-12664][ML] Expose probability in mlp model ## What changes were proposed in this pull request? Modify MLP model to inherit `ProbabilisticClassificationModel` and so that it can expose the probability column when transforming data. ## How was this patch tested? Test added. Author: WeichenXu Closes #17373 from WeichenXu123/expose_probability_in_mlp_model. --- .../scala/org/apache/spark/ml/ann/Layer.scala | 53 ++++++++++++++++--- .../MultilayerPerceptronClassifier.scala | 17 ++++-- .../apache/spark/ml/ann/GradientSuite.scala | 2 +- .../MultilayerPerceptronClassifierSuite.scala | 42 +++++++++++++++ python/pyspark/ml/classification.py | 4 +- 5 files changed, 103 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index e7e0dae0b5a01..014ff07c21158 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -361,17 +361,42 @@ private[ann] trait TopologyModel extends Serializable { * Forward propagation * * @param data input data + * @param includeLastLayer Include the last layer in the output. In + * MultilayerPerceptronClassifier, the last layer is always softmax; + * the last layer of outputs is needed for class predictions, but not + * for rawPrediction. + * * @return array of outputs for each of the layers */ - def forward(data: BDM[Double]): Array[BDM[Double]] + def forward(data: BDM[Double], includeLastLayer: Boolean): Array[BDM[Double]] /** - * Prediction of the model + * Prediction of the model. See {@link ProbabilisticClassificationModel} * - * @param data input data + * @param features input features * @return prediction */ - def predict(data: Vector): Vector + def predict(features: Vector): Vector + + /** + * Raw prediction of the model. See {@link ProbabilisticClassificationModel} + * + * @param features input features + * @return raw prediction + * + * Note: This interface is only used for classification Model. + */ + def predictRaw(features: Vector): Vector + + /** + * Probability of the model. See {@link ProbabilisticClassificationModel} + * + * @param rawPrediction raw prediction vector + * @return probability + * + * Note: This interface is only used for classification Model. 
+ */ + def raw2ProbabilityInPlace(rawPrediction: Vector): Vector /** * Computes gradient for the network @@ -463,7 +488,7 @@ private[ml] class FeedForwardModel private( private var outputs: Array[BDM[Double]] = null private var deltas: Array[BDM[Double]] = null - override def forward(data: BDM[Double]): Array[BDM[Double]] = { + override def forward(data: BDM[Double], includeLastLayer: Boolean): Array[BDM[Double]] = { // Initialize output arrays for all layers. Special treatment for InPlace val currentBatchSize = data.cols // TODO: allocate outputs as one big array and then create BDMs from it @@ -481,7 +506,8 @@ private[ml] class FeedForwardModel private( } } layerModels(0).eval(data, outputs(0)) - for (i <- 1 until layerModels.length) { + val end = if (includeLastLayer) layerModels.length else layerModels.length - 1 + for (i <- 1 until end) { layerModels(i).eval(outputs(i - 1), outputs(i)) } outputs @@ -492,7 +518,7 @@ private[ml] class FeedForwardModel private( target: BDM[Double], cumGradient: Vector, realBatchSize: Int): Double = { - val outputs = forward(data) + val outputs = forward(data, true) val currentBatchSize = data.cols // TODO: allocate deltas as one big array and then create BDMs from it if (deltas == null || deltas(0).cols != currentBatchSize) { @@ -527,9 +553,20 @@ private[ml] class FeedForwardModel private( override def predict(data: Vector): Vector = { val size = data.size - val result = forward(new BDM[Double](size, 1, data.toArray)) + val result = forward(new BDM[Double](size, 1, data.toArray), true) Vectors.dense(result.last.toArray) } + + override def predictRaw(data: Vector): Vector = { + val result = forward(new BDM[Double](data.size, 1, data.toArray), false) + Vectors.dense(result(result.length - 2).toArray) + } + + override def raw2ProbabilityInPlace(data: Vector): Vector = { + val dataMatrix = new BDM[Double](data.size, 1, data.toArray) + layerModels.last.eval(dataMatrix, dataMatrix) + data + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index ceba11edc93be..14a0c9f5a66dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -32,7 +32,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset /** Params for Multilayer Perceptron. 
*/ -private[classification] trait MultilayerPerceptronParams extends PredictorParams +private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver { import MultilayerPerceptronClassifier._ @@ -143,7 +143,8 @@ private object LabelConverter { @Since("1.5.0") class MultilayerPerceptronClassifier @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] + extends ProbabilisticClassifier[Vector, MultilayerPerceptronClassifier, + MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams with DefaultParamsWritable { @Since("1.5.0") @@ -301,13 +302,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val layers: Array[Int], @Since("2.0.0") val weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] + extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel] with Serializable with MLWritable { @Since("1.6.0") override val numFeatures: Int = layers.head - private val mlpModel = FeedForwardTopology + private[ml] val mlpModel = FeedForwardTopology .multiLayerPerceptron(layers, softmaxOnTop = true) .model(weights) @@ -335,6 +336,14 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this) + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + mlpModel.raw2ProbabilityInPlace(rawPrediction) + } + + override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features) + + override def numClasses: Int = layers.last } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala index f0c0183323c92..2f225645bdfc4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala @@ -64,7 +64,7 @@ class GradientSuite extends SparkFunSuite with MLlibTestSparkContext { } private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = { - val outputs = model.forward(input) + val outputs = model.forward(input, true) model.layerModels.last match { case layerWithLoss: LossFunction => layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index ce54c3df4f3f6..c294e4ad54bf7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions._ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -82,6 +83,47 @@ class 
MultilayerPerceptronClassifierSuite } } + test("Predicted class probabilities: calibration on toy dataset") { + val layers = Array[Int](4, 5, 2) + + val strongDataset = Seq( + (Vectors.dense(1, 2, 3, 4), 0d, Vectors.dense(1d, 0d)), + (Vectors.dense(4, 3, 2, 1), 1d, Vectors.dense(0d, 1d)), + (Vectors.dense(1, 1, 1, 1), 0d, Vectors.dense(.5, .5)), + (Vectors.dense(1, 1, 1, 1), 1d, Vectors.dense(.5, .5)) + ).toDF("features", "label", "expectedProbability") + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(strongDataset) + val result = model.transform(strongDataset) + result.select("probability", "expectedProbability").collect().foreach { + case Row(p: Vector, e: Vector) => + assert(p ~== e absTol 1e-3) + } + } + + test("test model probability") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + model.setProbabilityCol("probability") + val result = model.transform(dataset) + val features2prob = udf { features: Vector => model.mlpModel.predict(features) } + result.select(features2prob(col("features")), col("probability")).collect().foreach { + case Row(p1: Vector, p2: Vector) => + assert(p1 ~== p2 absTol 1e-3) + } + } + test("Test setWeights by training restart") { val dataFrame = Seq( (Vectors.dense(0.0, 0.0), 0.0), diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 235cee48bc6a6..f0f42a34942d7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1378,7 +1378,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, >>> testDF = spark.createDataFrame([ ... (Vectors.dense([1.0, 0.0]),), ... (Vectors.dense([0.0, 0.0]),)], ["features"]) - >>> model.transform(testDF).show() + >>> model.transform(testDF).select("features", "prediction").show() +---------+----------+ | features|prediction| +---------+----------+ @@ -1512,7 +1512,7 @@ def getInitialWeights(self): return self.getOrDefault(self.initialWeights) -class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable, +class MultilayerPerceptronClassificationModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by MultilayerPerceptronClassifier. From 1662e93119d68498942386906de309d35f4a135f Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Wed, 23 Aug 2017 11:51:11 -0500 Subject: [PATCH 034/187] [SPARK-21501] Change CacheLoader to limit entries based on memory footprint Right now the spark shuffle service has a cache for index files. It is based on a # of files cached (spark.shuffle.service.index.cache.entries). This can cause issues if people have a lot of reducers because the size of each entry can fluctuate based on the # of reducers. We saw an issues with a job that had 170000 reducers and it caused NM with spark shuffle service to use 700-800MB or memory in NM by itself. We should change this cache to be memory based and only allow a certain memory size used. When I say memory based I mean the cache should have a limit of say 100MB. https://issues.apache.org/jira/browse/SPARK-21501 Manual Testing with 170000 reducers has been performed with cache loaded up to max 100MB default limit, with each shuffle index file of size 1.3MB. 
Eviction takes place as soon as the total cache size reaches the 100MB limit and the objects will be ready for garbage collection there by avoiding NM to crash. No notable difference in runtime has been observed. Author: Sanket Chintapalli Closes #18940 from redsanket/SPARK-21501. --- .../org/apache/spark/network/util/TransportConf.java | 4 ++++ .../network/shuffle/ExternalShuffleBlockResolver.java | 11 +++++++++-- .../network/shuffle/ShuffleIndexInformation.java | 11 ++++++++++- core/src/main/scala/org/apache/spark/SparkConf.scala | 4 +++- docs/configuration.md | 6 +++--- 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 88256b810bf04..fa2ff42de07d0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -67,6 +67,10 @@ public int getInt(String name, int defaultValue) { return conf.getInt(name, defaultValue); } + public String get(String name, String defaultValue) { + return conf.get(name, defaultValue); + } + private String getConfKey(String suffix) { return "spark." + module + "." + suffix; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index d7ec0e299dead..e6399897be9c2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -33,6 +33,7 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; +import com.google.common.cache.Weigher; import com.google.common.collect.Maps; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; @@ -104,7 +105,7 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF Executor directoryCleaner) throws IOException { this.conf = conf; this.registeredExecutorFile = registeredExecutorFile; - int indexCacheEntries = conf.getInt("spark.shuffle.service.index.cache.entries", 1024); + String indexCacheSize = conf.get("spark.shuffle.service.index.cache.size", "100m"); CacheLoader indexCacheLoader = new CacheLoader() { public ShuffleIndexInformation load(File file) throws IOException { @@ -112,7 +113,13 @@ public ShuffleIndexInformation load(File file) throws IOException { } }; shuffleIndexCache = CacheBuilder.newBuilder() - .maximumSize(indexCacheEntries).build(indexCacheLoader); + .maximumWeight(JavaUtils.byteStringAsBytes(indexCacheSize)) + .weigher(new Weigher() { + public int weigh(File file, ShuffleIndexInformation indexInfo) { + return indexInfo.getSize(); + } + }) + .build(indexCacheLoader); db = LevelDBProvider.initLevelDB(this.registeredExecutorFile, CURRENT_VERSION, mapper); if (db != null) { executors = reloadRegisteredExecutors(db); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 39ca9ba574853..386738ece51a6 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -31,9 +31,10 @@ public class ShuffleIndexInformation { /** offsets as long buffer */ private final LongBuffer offsets; + private int size; public ShuffleIndexInformation(File indexFile) throws IOException { - int size = (int)indexFile.length(); + size = (int)indexFile.length(); ByteBuffer buffer = ByteBuffer.allocate(size); offsets = buffer.asLongBuffer(); DataInputStream dis = null; @@ -47,6 +48,14 @@ public ShuffleIndexInformation(File indexFile) throws IOException { } } + /** + * Size of the index file + * @return size + */ + public int getSize() { + return size; + } + /** * Get index offset for a particular reducer. */ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 715cfdcc8f4ef..e61f943af49f2 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -597,7 +597,9 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more") + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", + "Not used any more. Please use spark.shuffle.service.index.cache.size") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/docs/configuration.md b/docs/configuration.md index e7c0306920e08..6e9fe591b70a3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -627,10 +627,10 @@ Apart from these, the following properties are also available, and may be useful - spark.shuffle.service.index.cache.entries - 1024 + spark.shuffle.service.index.cache.size + 100m - Max number of entries to keep in the index cache of the shuffle service. + Cache entries limited to the specified memory footprint. From 6942aeeb0a0095a1ba85a817eb9e0edc410e5624 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 23 Aug 2017 12:02:24 -0700 Subject: [PATCH 035/187] [SPARK-21603][SQL][FOLLOW-UP] Change the default value of maxLinesPerFunction into 4000 ## What changes were proposed in this pull request? This pr changed the default value of `maxLinesPerFunction` into `4000`. In #18810, we had this new option to disable code generation for too long functions and I found this option only affected `Q17` and `Q66` in TPC-DS. But, `Q66` had some performance regression: ``` Q17 w/o #18810, 3224ms --> q17 w/#18810, 2627ms (improvement) Q66 w/o #18810, 1712ms --> q66 w/#18810, 3032ms (regression) ``` To keep the previous performance in TPC-DS, we better set higher value at `maxLinesPerFunction` by default. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #19021 from maropu/SPARK-21603-FOLLOWUP-1. 
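The hunk below only shows the option's doc string and default changing, so the fully qualified key is not visible here; assuming it is the `spark.sql.codegen.maxLinesPerFunction` option named in the title, a session that prefers the previous behaviour could pin the old threshold explicitly:

```scala
// Hedged sketch: the exact config key is assumed from the patch title, since only its
// doc/default appear in the diff below. `spark` is the usual SparkSession.
spark.conf.set("spark.sql.codegen.maxLinesPerFunction", "2667")  // pre-patch threshold
// ... run the affected query (e.g. TPC-DS Q66), then return to the new default:
spark.conf.set("spark.sql.codegen.maxLinesPerFunction", "4000")
```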
--- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2c7397c1ec774..a685099505ee8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -577,10 +577,10 @@ object SQLConf { .doc("The maximum lines of a single Java function generated by whole-stage codegen. " + "When the generated function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - "The default value 2667 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 3.") + "The default value 4000 is the max length of byte code JIT supported " + + "for a single function(8000) divided by 2.") .intConf - .createWithDefault(2667) + .createWithDefault(4000) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") From b8aaef49fbf02401c874b06d17cbe354f739b9e7 Mon Sep 17 00:00:00 2001 From: 10129659 Date: Wed, 23 Aug 2017 20:35:08 -0700 Subject: [PATCH 036/187] [SPARK-21807][SQL] Override ++ operation in ExpressionSet to reduce clone time ## What changes were proposed in this pull request? The getAliasedConstraints fuction in LogicalPlan.scala will clone the expression set when an element added, and it will take a long time. This PR add a function to add multiple elements at once to reduce the clone time. Before modified, the cost of getAliasedConstraints is: 100 expressions: 41 seconds 150 expressions: 466 seconds After modified, the cost of getAliasedConstraints is: 100 expressions: 1.8 seconds 150 expressions: 6.5 seconds The test is like this: test("getAliasedConstraints") { val expressionNum = 150 val aggExpression = (1 to expressionNum).map(i => Alias(Count(Literal(1)), s"cnt$i")()) val aggPlan = Aggregate(Nil, aggExpression, LocalRelation()) val beginTime = System.currentTimeMillis() val expressions = aggPlan.validConstraints println(s"validConstraints cost: ${System.currentTimeMillis() - beginTime}ms") // The size of Aliased expression is n * (n - 1) / 2 + n assert( expressions.size === expressionNum * (expressionNum - 1) / 2 + expressionNum) } (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Run new added test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 10129659 Closes #19022 from eatoncys/getAliasedConstraints. 
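The cost difference comes from where the clone happens: `+` copies the backing collections on every call, so folding N new constraints over the set performs N full copies, whereas a batched `++` copies once and then adds into that single private copy. A simplified standalone sketch of the pattern (plain Scala, not the actual Catalyst classes; deduplication/canonicalization is omitted):

```scala
import scala.collection.mutable

// Simplified stand-in for a clone-on-add set: `+` copies the buffer per element,
// the bulk `++` copies it once for the whole batch.
class CloningSet[A] private (private val elems: mutable.ArrayBuffer[A]) {
  def this() = this(mutable.ArrayBuffer.empty[A])

  def +(e: A): CloningSet[A] = {
    val copy = new CloningSet(elems.clone())   // one full copy per element added
    copy.elems += e
    copy
  }

  def ++(es: TraversableOnce[A]): CloningSet[A] = {
    val copy = new CloningSet(elems.clone())   // a single copy for the whole batch
    es.foreach(copy.elems += _)
    copy
  }

  def size: Int = elems.size
}
```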
--- .../spark/sql/catalyst/expressions/ExpressionSet.scala | 8 +++++++- .../sql/catalyst/expressions/ExpressionSetSuite.scala | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index ede0b1654bbd6..305ac90e245b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable +import scala.collection.{mutable, GenTraversableOnce} import scala.collection.mutable.ArrayBuffer object ExpressionSet { @@ -67,6 +67,12 @@ class ExpressionSet protected( newSet } + override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + elems.foreach(newSet.add) + newSet + } + override def -(elem: Expression): ExpressionSet = { val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index d617ad540d5ff..a1000a0e80799 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -210,4 +210,13 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet - (aLower + 1)).size == 0) } + + test("add multiple elements to set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + val setToAddWithSameExpression = ExpressionSet(aUpper + 1 :: aUpper + 2 :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameExpression).size == 2) + assert((initialSet ++ setToAddWithOutSameExpression).size == 3) + } } From 43cbfad9992624d89bbb3209d1f5b765c7947bb9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 23 Aug 2017 21:35:17 -0700 Subject: [PATCH 037/187] [SPARK-21805][SPARKR] Disable R vignettes code on Windows ## What changes were proposed in this pull request? Code in vignettes requires winutils on windows to run, when publishing to CRAN or building from source, winutils might not be available, so it's better to disable code run (so resulting vigenttes will not have output from code, but text is still there and code is still there) fix * checking re-building of vignette outputs ... WARNING and > %LOCALAPPDATA% not found. Please define the environment variable or restart and enter an installation path in localDir. ## How was this patch tested? jenkins, appveyor, r-hub before: https://artifacts.r-hub.io/SparkR_2.2.0.tar.gz-49cecef3bb09db1db130db31604e0293/SparkR.Rcheck/00check.log after: https://artifacts.r-hub.io/SparkR_2.2.0.tar.gz-86a066c7576f46794930ad114e5cff7c/SparkR.Rcheck/00check.log Author: Felix Cheung Closes #19016 from felixcheung/rvigwind. 
--- R/pkg/DESCRIPTION | 2 +- R/pkg/R/install.R | 6 +++++- R/pkg/vignettes/sparkr-vignettes.Rmd | 11 +++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index b739d423a36cc..d1c846c048274 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -2,7 +2,7 @@ Package: SparkR Type: Package Version: 2.3.0 Title: R Frontend for Apache Spark -Description: The SparkR package provides an R Frontend for Apache Spark. +Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index ec931befa2854..492dee68e164d 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -270,7 +270,11 @@ sparkCachePath <- function() { if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { - stop(paste("%LOCALAPPDATA% not found.", + message("%LOCALAPPDATA% not found. Falling back to %USERPROFILE%.") + winAppPath <- Sys.getenv("USERPROFILE", unset = NA) + } + if (is.na(winAppPath)) { + stop(paste("%LOCALAPPDATA% and %USERPROFILE% not found.", "Please define the environment variable", "or restart and enter an installation path in localDir.")) } else { diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2301a64576d0e..caeae72e37bbf 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -27,6 +27,17 @@ vignette: > limitations under the License. --> +```{r setup, include=FALSE} +library(knitr) +opts_hooks$set(eval = function(options) { + # override eval to FALSE only on windows + if (.Platform$OS.type == "windows") { + options$eval = FALSE + } + options +}) +``` + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). From ce0d3bb377766bdf4df7852272557ae846408877 Mon Sep 17 00:00:00 2001 From: "Susan X. Huynh" Date: Thu, 24 Aug 2017 10:05:38 +0100 Subject: [PATCH 038/187] [SPARK-21694][MESOS] Support Mesos CNI network labels JIRA ticket: https://issues.apache.org/jira/browse/SPARK-21694 ## What changes were proposed in this pull request? Spark already supports launching containers attached to a given CNI network by specifying it via the config `spark.mesos.network.name`. This PR adds support to pass in network labels to CNI plugins via a new config option `spark.mesos.network.labels`. These network labels are key-value pairs that are set in the `NetworkInfo` of both the driver and executor tasks. More details in the related Mesos documentation: http://mesos.apache.org/documentation/latest/cni/#mesos-meta-data-to-cni-plugins ## How was this patch tested? Unit tests, for both driver and executor tasks. Manual integration test to submit a job with the `spark.mesos.network.labels` option, hit the mesos/state.json endpoint, and check that the labels are set in the driver and executor tasks. ArtRand skonto Author: Susan X. Huynh Closes #18910 from susanxhuynh/sh-mesos-cni-labels. 
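A usage sketch (the network name here is made up; the keys and the `key:value,key:value` label syntax come from the config descriptions added below):

```scala
import org.apache.spark.SparkConf

// Attach the driver and executors to a named CNI network and pass labels to the CNI plugins.
val conf = new SparkConf()
  .set("spark.mesos.network.name", "example-cni-network")     // hypothetical network name
  .set("spark.mesos.network.labels", "key1:val1,key2:val2")   // format from the NETWORK_LABELS doc
```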
--- docs/running-on-mesos.md | 14 ++++++++++++++ .../apache/spark/deploy/mesos/config.scala | 19 +++++++++++++++++-- .../MesosCoarseGrainedSchedulerBackend.scala | 2 +- .../mesos/MesosSchedulerBackendUtil.scala | 9 +++++++-- .../mesos/MesosClusterSchedulerSuite.scala | 9 +++++++-- ...osCoarseGrainedSchedulerBackendSuite.scala | 9 +++++++-- 6 files changed, 53 insertions(+), 9 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ae3855084a650..0e5a20c578db3 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -537,6 +537,20 @@ See the [configuration page](configuration.html) for information on Spark config for more details. + + spark.mesos.network.labels + (none) + + Pass network labels to CNI plugins. This is a comma-separated list + of key-value pairs, where each key-value pair has the format key:value. + Example: + +
key1:val1,key2:val2
+ See + the Mesos CNI docs + for more details. + + spark.mesos.fetcherCache.enable false diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 6c8619e3c3c13..a5015b9243316 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -56,7 +56,7 @@ package object config { .stringConf .createOptional - private [spark] val DRIVER_LABELS = + private[spark] val DRIVER_LABELS = ConfigBuilder("spark.mesos.driver.labels") .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " + "pairs should be separated by a colon, and commas used to list more than one." + @@ -64,10 +64,25 @@ package object config { .stringConf .createOptional - private [spark] val DRIVER_FAILOVER_TIMEOUT = + private[spark] val DRIVER_FAILOVER_TIMEOUT = ConfigBuilder("spark.mesos.driver.failoverTimeout") .doc("Amount of time in seconds that the master will wait to hear from the driver, " + "during a temporary disconnection, before tearing down all the executors.") .doubleConf .createWithDefault(0.0) + + private[spark] val NETWORK_NAME = + ConfigBuilder("spark.mesos.network.name") + .doc("Attach containers to the given named network. If this job is launched " + + "in cluster mode, also launch the driver in the given named network.") + .stringConf + .createOptional + + private[spark] val NETWORK_LABELS = + ConfigBuilder("spark.mesos.network.labels") + .doc("Network labels to pass to CNI plugins. This is a comma-separated list " + + "of key-value pairs, where each key-value pair has the format key:value. " + + "Example: key1:val1,key2:val2") + .stringConf + .createOptional } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5ecd466194d8b..26699873145b4 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -670,7 +670,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def executorHostname(offer: Offer): String = { - if (sc.conf.getOption("spark.mesos.network.name").isDefined) { + if (sc.conf.get(NETWORK_NAME).isDefined) { // The agent's IP is not visible in a CNI container, so we bind to 0.0.0.0 "0.0.0.0" } else { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index fbcbc55099ec5..e5c1e801f2772 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -21,6 +21,7 @@ import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Vo import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.mesos.config.{NETWORK_LABELS, NETWORK_NAME} 
import org.apache.spark.internal.Logging /** @@ -161,8 +162,12 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { volumes.foreach(_.foreach(containerInfo.addVolumes(_))) } - conf.getOption("spark.mesos.network.name").map { name => - val info = NetworkInfo.newBuilder().setName(name).build() + conf.get(NETWORK_NAME).map { name => + val networkLabels = MesosProtoUtils.mesosLabels(conf.get(NETWORK_LABELS).getOrElse("")) + val info = NetworkInfo.newBuilder() + .setName(name) + .setLabels(networkLabels) + .build() containerInfo.addNetworkInfos(info) } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 0bb47906347d5..50bb501071509 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -222,7 +222,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(env.getOrElse("TEST_ENV", null) == "TEST_VAL") } - test("supports spark.mesos.network.name") { + test("supports spark.mesos.network.name and spark.mesos.network.labels") { setScheduler() val mem = 1000 @@ -233,7 +233,8 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.executor.home" -> "test", "spark.app.name" -> "test", - "spark.mesos.network.name" -> "test-network-name"), + "spark.mesos.network.name" -> "test-network-name", + "spark.mesos.network.labels" -> "key1:val1,key2:val2"), "s1", new Date())) @@ -246,6 +247,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val networkInfos = launchedTasks.head.getContainer.getNetworkInfosList assert(networkInfos.size == 1) assert(networkInfos.get(0).getName == "test-network-name") + assert(networkInfos.get(0).getLabels.getLabels(0).getKey == "key1") + assert(networkInfos.get(0).getLabels.getLabels(0).getValue == "val1") + assert(networkInfos.get(0).getLabels.getLabels(1).getKey == "key2") + assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2") } test("supports spark.mesos.driver.labels") { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index a8175e29bc9cf..ab29c295dd893 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -568,9 +568,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getLabels.equals(taskLabels)) } - test("mesos supports spark.mesos.network.name") { + test("mesos supports spark.mesos.network.name and spark.mesos.network.labels") { setBackend(Map( - "spark.mesos.network.name" -> "test-network-name" + "spark.mesos.network.name" -> "test-network-name", + "spark.mesos.network.labels" -> "key1:val1,key2:val2" )) val (mem, cpu) = (backend.executorMemory(sc), 4) @@ -582,6 +583,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val networkInfos = 
launchedTasks.head.getContainer.getNetworkInfosList
     assert(networkInfos.size == 1)
     assert(networkInfos.get(0).getName == "test-network-name")
+    assert(networkInfos.get(0).getLabels.getLabels(0).getKey == "key1")
+    assert(networkInfos.get(0).getLabels.getLabels(0).getValue == "val1")
+    assert(networkInfos.get(0).getLabels.getLabels(1).getKey == "key2")
+    assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2")
   }

   test("supports spark.scheduler.minRegisteredResourcesRatio") {

From 846bc61cf5aa522dc755d50359ef3856ef2b17bf Mon Sep 17 00:00:00 2001
From: lufei
Date: Thu, 24 Aug 2017 10:07:27 +0100
Subject: [PATCH 039/187] [MINOR][SQL] The comment of class ExchangeCoordinator
 contains a typo and a context error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

The example given in the comment of class ExchangeCoordinator produces four post-shuffle partitions, but the current comment says "three".

## How was this patch tested?

Author: lufei

Closes #19028 from figo77/SPARK-21816.

---
 .../spark/sql/execution/exchange/ExchangeCoordinator.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index deb2c24d0f16e..9fc4ffb651ec8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -75,7 +75,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
  * For example, we have two stages with the following pre-shuffle partition size statistics:
  *   stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
  *   stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB]
- * assuming the target input size is 128 MB, we will have three post-shuffle partitions,
+ * assuming the target input size is 128 MB, we will have four post-shuffle partitions,
  * which are:
  *  - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB)
  *  - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB)

From 95713eb4f22de4e16617a605f74a1d6373ed270b Mon Sep 17 00:00:00 2001
From: Jen-Ming Chung
Date: Thu, 24 Aug 2017 19:24:00 +0900
Subject: [PATCH 040/187] [SPARK-21804][SQL] json_tuple returns null values
 within repeated columns except the first one

## What changes were proposed in this pull request?

When json_tuple extracts values from JSON, it returns null values within repeated columns except the first one, as shown below:

```scala
scala> spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a', 'b', 'a')""").show()
+---+---+----+
| c0| c1|  c2|
+---+---+----+
|  1|  2|null|
+---+---+----+
```

I think this should be consistent with Hive's implementation:

```
hive> SELECT json_tuple('{"a": 1, "b": 2}', 'a', 'a');
...
1    1
```

In this PR, we locate all the matched indices in `fieldNames` instead of returning only the first matched index (i.e., `indexOf`).

## How was this patch tested?

Added a test in `JsonExpressionsSuite`.

Author: Jen-Ming Chung

Closes #19017 from jmchung/SPARK-21804.
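For illustration, here is a minimal standalone sketch of the index-propagation idea (hypothetical names, not the actual `JsonTuple` implementation, which follows in the diff below):

```scala
// Sketch of the SPARK-21804 idea: when a parsed JSON field matches several of the
// requested field names, copy its value into every matching slot, not just the first.
object RepeatedFieldSketch {
  def fillMatches(
      fieldNames: Seq[String],
      jsonField: String,
      jsonValue: String,
      row: Array[String]): Unit = {
    var idx = fieldNames.indexOf(jsonField)
    while (idx >= 0) {
      row(idx) = jsonValue                          // fill this matching slot
      idx = fieldNames.indexOf(jsonField, idx + 1)  // keep scanning for duplicates
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors json_tuple('{"a":1, "b":2}', 'a', 'b', 'a')
    val requested = Seq("a", "b", "a")
    val row = Array.fill[String](requested.length)(null)
    fillMatches(requested, "a", "1", row)
    fillMatches(requested, "b", "2", row)
    println(row.mkString(", "))  // prints: 1, 2, 1  (the duplicate column is no longer null)
  }
}
```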
--- .../sql/catalyst/expressions/jsonExpressions.scala | 12 ++++++++++-- .../catalyst/expressions/JsonExpressionsSuite.scala | 10 ++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index c3757373a3cf9..ee5da1a83a4ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -436,7 +436,8 @@ case class JsonTuple(children: Seq[Expression]) while (parser.nextToken() != JsonToken.END_OBJECT) { if (parser.getCurrentToken == JsonToken.FIELD_NAME) { // check to see if this field is desired in the output - val idx = fieldNames.indexOf(parser.getCurrentName) + val jsonField = parser.getCurrentName + var idx = fieldNames.indexOf(jsonField) if (idx >= 0) { // it is, copy the child tree to the correct location in the output row val output = new ByteArrayOutputStream() @@ -447,7 +448,14 @@ case class JsonTuple(children: Seq[Expression]) generator => copyCurrentStructure(generator, parser) } - row(idx) = UTF8String.fromBytes(output.toByteArray) + val jsonValue = UTF8String.fromBytes(output.toByteArray) + + // SPARK-21804: json_tuple returns null values within repeated columns + // except the first one; so that we need to check the remaining fields. + do { + row(idx) = jsonValue + idx = fieldNames.indexOf(jsonField, idx + 1) + } while (idx >= 0) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 1cd2b4fc18a5c..9991bda165a01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -373,6 +373,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("2"))) } + test("SPARK-21804: json_tuple returns null values within repeated columns except the first one") { + checkJsonTuple( + JsonTuple(Literal("""{"f1": 1, "f2": 2}""") :: + NonFoldableLiteral("f1") :: + NonFoldableLiteral("cast(NULL AS STRING)") :: + NonFoldableLiteral("f1") :: + Nil), + InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("1"))) + } + val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID) test("from_json") { From dc5d34d8dcd6526d1dfdac8606661561c7576a62 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 Aug 2017 20:29:03 +0900 Subject: [PATCH 041/187] [SPARK-19165][PYTHON][SQL] PySpark APIs using columns as arguments should validate input types for column ## What changes were proposed in this pull request? While preparing to take over https://github.com/apache/spark/pull/16537, I realised a (I think) better approach to make the exception handling in one point. This PR proposes to fix `_to_java_column` in `pyspark.sql.column`, which most of functions in `functions.py` and some other APIs use. This `_to_java_column` basically looks not working with other types than `pyspark.sql.column.Column` or string (`str` and `unicode`). 
If this is not `Column`, then it calls `_create_column_from_name` which calls `functions.col` within JVM: https://github.com/apache/spark/blob/42b9eda80e975d970c3e8da4047b318b83dd269f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L76 And it looks we only have `String` one with `col`. So, these should work: ```python >>> from pyspark.sql.column import _to_java_column, Column >>> _to_java_column("a") JavaObject id=o28 >>> _to_java_column(u"a") JavaObject id=o29 >>> _to_java_column(spark.range(1).id) JavaObject id=o33 ``` whereas these do not: ```python >>> _to_java_column(1) ``` ``` ... py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace: py4j.Py4JException: Method col([class java.lang.Integer]) does not exist ... ``` ```python >>> _to_java_column([]) ``` ``` ... py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace: py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist ... ``` ```python >>> class A(): pass >>> _to_java_column(A()) ``` ``` ... AttributeError: 'A' object has no attribute '_get_object_id' ``` Meaning most of functions using `_to_java_column` such as `udf` or `to_json` or some other APIs throw an exception as below: ```python >>> from pyspark.sql.functions import udf >>> udf(lambda x: x)(None) ``` ``` ... py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col. : java.lang.NullPointerException ... ``` ```python >>> from pyspark.sql.functions import to_json >>> to_json(None) ``` ``` ... py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.sql.functions.col. : java.lang.NullPointerException ... ``` **After this PR**: ```python >>> from pyspark.sql.functions import udf >>> udf(lambda x: x)(None) ... ``` ``` TypeError: Invalid argument, not a string or column: None of type . For column literals, use 'lit', 'array', 'struct' or 'create_map' functions. ``` ```python >>> from pyspark.sql.functions import to_json >>> to_json(None) ``` ``` ... TypeError: Invalid argument, not a string or column: None of type . For column literals, use 'lit', 'array', 'struct' or 'create_map' functions. ``` ## How was this patch tested? Unit tests added in `python/pyspark/sql/tests.py` and manual tests. Author: hyukjinkwon Author: zero323 Closes #19027 from HyukjinKwon/SPARK-19165. --- python/pyspark/sql/column.py | 8 +++++++- python/pyspark/sql/tests.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index b172f38ea22d0..43b38a2cd477c 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -44,8 +44,14 @@ def _create_column_from_name(name): def _to_java_column(col): if isinstance(col, Column): jcol = col._jc - else: + elif isinstance(col, basestring): jcol = _create_column_from_name(col) + else: + raise TypeError( + "Invalid argument, not a string or column: " + "{0} of type {1}. 
" + "For column literals, use 'lit', 'array', 'struct' or 'create_map' " + "function.".format(col, type(col))) return jcol diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 45a3f9e7165f1..1ecde68fb0ac1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -704,6 +704,31 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_validate_column_types(self): + from pyspark.sql.functions import udf, to_json + from pyspark.sql.column import _to_java_column + + self.assertTrue("Column" in _to_java_column("a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: _to_java_column(1)) + + class A(): + pass + + self.assertRaises(TypeError, lambda: _to_java_column(A())) + self.assertRaises(TypeError, lambda: _to_java_column([])) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: udf(lambda x: x)(None)) + self.assertRaises(TypeError, lambda: to_json(1)) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) From 9e33954ddfe1148f69e523c89827feb76ba892c9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 24 Aug 2017 21:13:44 +0800 Subject: [PATCH 042/187] [SPARK-21745][SQL] Refactor ColumnVector hierarchy to make ColumnVector read-only and to introduce WritableColumnVector. ## What changes were proposed in this pull request? This is a refactoring of `ColumnVector` hierarchy and related classes. 1. make `ColumnVector` read-only 2. introduce `WritableColumnVector` with write interface 3. remove `ReadOnlyColumnVector` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #18958 from ueshin/issues/SPARK-21745. 
--- .../expressions/codegen/CodeGenerator.scala | 28 +- .../parquet/VectorizedColumnReader.java | 31 +- .../VectorizedParquetRecordReader.java | 23 +- .../parquet/VectorizedPlainValuesReader.java | 16 +- .../parquet/VectorizedRleValuesReader.java | 87 ++- .../parquet/VectorizedValuesReader.java | 16 +- .../vectorized/AggregateHashMap.java | 10 +- .../vectorized/ArrowColumnVector.java | 45 +- .../execution/vectorized/ColumnVector.java | 632 +--------------- .../vectorized/ColumnVectorUtils.java | 18 +- .../execution/vectorized/ColumnarBatch.java | 106 +-- .../vectorized/OffHeapColumnVector.java | 34 +- .../vectorized/OnHeapColumnVector.java | 35 +- .../vectorized/ReadOnlyColumnVector.java | 251 ------- .../vectorized/WritableColumnVector.java | 674 ++++++++++++++++++ .../VectorizedHashMapGenerator.scala | 39 +- .../vectorized/ColumnarBatchBenchmark.scala | 23 +- .../vectorized/ColumnarBatchSuite.scala | 109 +-- 18 files changed, 1078 insertions(+), 1099 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 807765c1e00a1..38538630c8b32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -464,14 +464,13 @@ class CodegenContext { /** * Returns the specialized code to set a given value in a column vector for a given `DataType`. */ - def setValue(batch: String, row: String, dataType: DataType, ordinal: Int, - value: String): String = { + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => - s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);" - case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});" - case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());" + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" case _ => throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") } @@ -482,37 +481,36 @@ class CodegenContext { * that could potentially be nullable. */ def updateColumn( - batch: String, - row: String, + vector: String, + rowId: String, dataType: DataType, - ordinal: Int, ev: ExprCode, nullable: Boolean): String = { if (nullable) { s""" if (!${ev.isNull}) { - ${setValue(batch, row, dataType, ordinal, ev.value)} + ${setValue(vector, rowId, dataType, ev.value)} } else { - $batch.column($ordinal).putNull($row); + $vector.putNull($rowId); } """ } else { - s"""${setValue(batch, row, dataType, ordinal, ev.value)};""" + s"""${setValue(vector, rowId, dataType, ev.value)};""" } } /** * Returns the specialized code to access a value from a column vector for a given `DataType`. 
*/ - def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = { + def getValue(vector: String, rowId: String, dataType: DataType): String = { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => - s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)" + s"$vector.get${primitiveTypeName(jt)}($rowId)" case t: DecimalType => - s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})" + s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" case StringType => - s"$batch.column($ordinal).getUTF8String($row)" + s"$vector.getUTF8String($rowId)" case _ => throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index fd8db1727212f..f37864a0f5393 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -135,9 +136,9 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, ColumnVector column) throws IOException { + void readBatch(int total, WritableColumnVector column) throws IOException { int rowId = 0; - ColumnVector dictionaryIds = null; + WritableColumnVector dictionaryIds = null; if (dictionary != null) { // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded @@ -219,8 +220,11 @@ void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column, - ColumnVector dictionaryIds) { + private void decodeDictionaryIds( + int rowId, + int num, + WritableColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: if (column.dataType() == DataTypes.IntegerType || @@ -346,13 +350,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) throws IOException { assert(column.dataType() == DataTypes.BooleanType); defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -370,7 +374,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce } } - private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { @@ -389,7 +393,7 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc } } - private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -400,7 +404,7 @@ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOEx } } - private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -411,7 +415,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE } } - private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -432,8 +436,11 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE } } - private void readFixedLenByteArrayBatch(int rowId, int num, - ColumnVector column, int arrayLen) throws IOException { + private void readFixedLenByteArrayBatch( + int rowId, + int num, + WritableColumnVector column, + int arrayLen) throws IOException { VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 04f8141d66e9d..0cacf0c9c93a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,6 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -90,6 +93,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private ColumnarBatch columnarBatch; + private WritableColumnVector[] columnVectors; + /** * If true, this class returns batches instead of rows. */ @@ -172,20 +177,26 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } } - columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); + int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + if (memMode == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + } + columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { - ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i); - columnarBatch.column(i + partitionIdx).setIsConstant(); + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); } } // Initialize missing columns with nulls. 
for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); - columnarBatch.column(i).setIsConstant(); + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); } } } @@ -226,7 +237,7 @@ public boolean nextBatch() throws IOException { int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnarBatch.column(i)); + columnReaders[i].readBatch(num, columnVectors[i]); } rowsReturned += num; columnarBatch.setNumRows(num); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 98018b7f48bd8..5b75f719339fb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,7 +20,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; @@ -56,7 +56,7 @@ public void skip() { } @Override - public final void readBooleans(int total, ColumnVector c, int rowId) { + public final void readBooleans(int total, WritableColumnVector c, int rowId) { // TODO: properly vectorize this for (int i = 0; i < total; i++) { c.putBoolean(rowId + i, readBoolean()); @@ -64,31 +64,31 @@ public final void readBooleans(int total, ColumnVector c, int rowId) { } @Override - public final void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, WritableColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public final void readLongs(int total, ColumnVector c, int rowId) { + public final void readLongs(int total, WritableColumnVector c, int rowId) { c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 8 * total; } @Override - public final void readFloats(int total, ColumnVector c, int rowId) { + public final void readFloats(int total, WritableColumnVector c, int rowId) { c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public final void readDoubles(int total, ColumnVector c, int rowId) { + public final void readDoubles(int total, WritableColumnVector c, int rowId) { c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 8 * total; } @Override - public final void readBytes(int total, ColumnVector c, int rowId) { + public final void readBytes(int total, WritableColumnVector c, int rowId) { for (int i = 0; i < total; i++) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
@@ -159,7 +159,7 @@ public final double readDouble() { } @Override - public final void readBinary(int total, ColumnVector v, int rowId) { + public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); int start = offset; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 62157389013bb..fc7fa70c39419 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,7 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; /** * A values reader for Parquet's run-length encoded data. This is based off of the version in @@ -177,7 +177,11 @@ public int readInteger() { * c[rowId] = null; * } */ - public void readIntegers(int total, ColumnVector c, int rowId, int level, + public void readIntegers( + int total, + WritableColumnVector c, + int rowId, + int level, VectorizedValuesReader data) { int left = total; while (left > 0) { @@ -208,8 +212,12 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } // TODO: can this code duplication be removed without a perf penalty? - public void readBooleans(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readBooleans( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -238,8 +246,12 @@ public void readBooleans(int total, ColumnVector c, } } - public void readBytes(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readBytes( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -268,8 +280,12 @@ public void readBytes(int total, ColumnVector c, } } - public void readShorts(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readShorts( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -300,8 +316,12 @@ public void readShorts(int total, ColumnVector c, } } - public void readLongs(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readLongs( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -330,8 +350,12 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, } } - public void readFloats(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readFloats( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -360,8 +384,12 @@ public void 
readFloats(int total, ColumnVector c, int rowId, int level, } } - public void readDoubles(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readDoubles( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -390,8 +418,12 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, } } - public void readBinarys(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readBinarys( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -424,8 +456,13 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, * Decoding for dictionary ids. The IDs are populated into `values` and the nullability is * populated into `nulls`. */ - public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + public void readIntegers( + int total, + WritableColumnVector values, + WritableColumnVector nulls, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -459,7 +496,7 @@ public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int // IDs. This is different than the above APIs that decodes definitions levels along with values. // Since this is only used to decode dictionary IDs, only decoding integers is supported. @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public void readIntegers(int total, WritableColumnVector c, int rowId) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -485,32 +522,32 @@ public byte readByte() { } @Override - public void readBytes(int total, ColumnVector c, int rowId) { + public void readBytes(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readLongs(int total, ColumnVector c, int rowId) { + public void readLongs(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readBinary(int total, ColumnVector c, int rowId) { + public void readBinary(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readBooleans(int total, ColumnVector c, int rowId) { + public void readBooleans(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readFloats(int total, ColumnVector c, int rowId) { + public void readFloats(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readDoubles(int total, ColumnVector c, int rowId) { + public void readDoubles(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 88418ca53fe1e..57d92ae27ece8 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.parquet.io.api.Binary; @@ -37,11 +37,11 @@ public interface VectorizedValuesReader { /* * Reads `total` values into `c` start at `c[rowId]` */ - void readBooleans(int total, ColumnVector c, int rowId); - void readBytes(int total, ColumnVector c, int rowId); - void readIntegers(int total, ColumnVector c, int rowId); - void readLongs(int total, ColumnVector c, int rowId); - void readFloats(int total, ColumnVector c, int rowId); - void readDoubles(int total, ColumnVector c, int rowId); - void readBinary(int total, ColumnVector c, int rowId); + void readBooleans(int total, WritableColumnVector c, int rowId); + void readBytes(int total, WritableColumnVector c, int rowId); + void readIntegers(int total, WritableColumnVector c, int rowId); + void readLongs(int total, WritableColumnVector c, int rowId); + void readFloats(int total, WritableColumnVector c, int rowId); + void readDoubles(int total, WritableColumnVector c, int rowId); + void readBinary(int total, WritableColumnVector c, int rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 25a565d32638d..1c94f706dc685 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -41,6 +41,7 @@ */ public class AggregateHashMap { + private OnHeapColumnVector[] columnVectors; private ColumnarBatch batch; private int[] buckets; private int numBuckets; @@ -62,7 +63,8 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int this.maxSteps = maxSteps; numBuckets = (int) (capacity / loadFactor); - batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema); + batch = new ColumnarBatch(schema, columnVectors, capacity); buckets = new int[numBuckets]; Arrays.fill(buckets, -1); } @@ -74,8 +76,8 @@ public AggregateHashMap(StructType schema) { public ColumnarBatch.Row findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { - batch.column(0).putLong(numRows, key); - batch.column(1).putLong(numRows, 0); + columnVectors[0].putLong(numRows, key); + columnVectors[1].putLong(numRows, 0); buckets[idx] = numRows++; } return batch.getRow(buckets[idx]); @@ -105,6 +107,6 @@ private long hash(long key) { } private boolean equals(int idx, long key1) { - return batch.column(0).getLong(buckets[idx]) == key1; + return columnVectors[0].getLong(buckets[idx]) == key1; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 59d66c599c518..be2a9c246747c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -29,12 +29,13 @@ /** * A column vector 
backed by Apache Arrow. */ -public final class ArrowColumnVector extends ReadOnlyColumnVector { +public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; - private final int valueCount; + private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { + int valueCount = accessor.getValueCount(); if (index < 0 || index >= valueCount) { throw new IndexOutOfBoundsException( String.format("index: %d, valueCount: %d", index, valueCount)); @@ -42,12 +43,23 @@ private void ensureAccessible(int index) { } private void ensureAccessible(int index, int count) { + int valueCount = accessor.getValueCount(); if (index < 0 || index + count > valueCount) { throw new IndexOutOfBoundsException( String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); } } + @Override + public int numNulls() { + return accessor.getNullCount(); + } + + @Override + public boolean anyNullsSet() { + return numNulls() > 0; + } + @Override public long nullsNativeAddress() { throw new RuntimeException("Cannot get native address for arrow column"); @@ -274,9 +286,20 @@ public byte[] getBinary(int rowId) { return accessor.getBinary(rowId); } + /** + * Returns the data for the underlying array. + */ + @Override + public ArrowColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + @Override + public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), ArrowUtils.fromArrowField(vector.getField()), - MemoryMode.OFF_HEAP); + super(ArrowUtils.fromArrowField(vector.getField())); if (vector instanceof NullableBitVector) { accessor = new BooleanAccessor((NullableBitVector) vector); @@ -302,7 +325,7 @@ public ArrowColumnVector(ValueVector vector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - childColumns = new ColumnVector[1]; + childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); resultArray = new ColumnVector.Array(childColumns[0]); } else if (vector instanceof MapVector) { @@ -317,9 +340,6 @@ public ArrowColumnVector(ValueVector vector) { } else { throw new UnsupportedOperationException(); } - valueCount = accessor.getValueCount(); - numNulls = accessor.getNullCount(); - anyNullsSet = numNulls > 0; } private abstract static class ArrowVectorAccessor { @@ -327,14 +347,9 @@ private abstract static class ArrowVectorAccessor { private final ValueVector vector; private final ValueVector.Accessor nulls; - private final int valueCount; - private final int nullCount; - ArrowVectorAccessor(ValueVector vector) { this.vector = vector; this.nulls = vector.getAccessor(); - this.valueCount = nulls.getValueCount(); - this.nullCount = nulls.getNullCount(); } final boolean isNullAt(int rowId) { @@ -342,11 +357,11 @@ final boolean isNullAt(int rowId) { } final int getValueCount() { - return valueCount; + return nulls.getValueCount(); } final int getNullCount() { - return nullCount; + return nulls.getNullCount(); } final void close() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 77966382881b8..a69dd9718fe33 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,23 +16,16 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** * This class represents a column of values and provides the main APIs to access the data - * values. It supports all the types and contains get/put APIs as well as their batched versions. + * values. It supports all the types and contains get APIs as well as their batched versions. * The batched versions are preferable whenever possible. * * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these @@ -40,34 +33,15 @@ * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. * Lengths and offsets are encoded identically to INTs. * Maps are just a special case of a two field struct. - * Strings are handled as an Array of ByteType. - * - * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the - * responsibility of the caller to call reserve() to ensure there is enough room before adding - * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), - * the lengths are known up front. * * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in the current RowBatch. * - * A ColumnVector should be considered immutable once originally created. In other words, it is not - * valid to call put APIs after reads until reset() is called. + * A ColumnVector should be considered immutable once originally created. * * ColumnVectors are intended to be reused. */ public abstract class ColumnVector implements AutoCloseable { - /** - * Allocates a column to store elements of `type` on or off heap. - * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is - * in number of elements, not number of bytes. - */ - public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { - if (mode == MemoryMode.OFF_HEAP) { - return new OffHeapColumnVector(capacity, type); - } else { - return new OnHeapColumnVector(capacity, type); - } - } /** * Holder object to return an array. This object is intended to be reused. Callers should @@ -278,75 +252,22 @@ public Object get(int ordinal, DataType dataType) { */ public final DataType dataType() { return type; } - /** - * Resets this column for writing. The currently stored values are no longer accessible. - */ - public void reset() { - if (isConstant) return; - - if (childColumns != null) { - for (ColumnVector c: childColumns) { - c.reset(); - } - } - numNulls = 0; - elementsAppended = 0; - if (anyNullsSet) { - putNotNulls(0, capacity); - anyNullsSet = false; - } - } - /** * Cleans up memory for this column. The column is not usable after this. * TODO: this should probably have ref-counted semantics. 
*/ public abstract void close(); - public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { - int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); - if (requiredCapacity <= newCapacity) { - try { - reserveInternal(newCapacity); - } catch (OutOfMemoryError outOfMemoryError) { - throwUnsupportedException(requiredCapacity, outOfMemoryError); - } - } else { - throwUnsupportedException(requiredCapacity, null); - } - } - } - - private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; - - if (cause != null) { - throw new RuntimeException(message, cause); - } else { - throw new RuntimeException(message); - } - } - - /** - * Ensures that there is enough storage to store capacity elements. That is, the put() APIs - * must work for all rowIds < capacity. - */ - protected abstract void reserveInternal(int capacity); - /** * Returns the number of nulls in this column. */ - public final int numNulls() { return numNulls; } + public abstract int numNulls(); /** * Returns true if any of the nulls indicator are set for this column. This can be used * as an optimization to prevent setting nulls. */ - public final boolean anyNullsSet() { return anyNullsSet; } + public abstract boolean anyNullsSet(); /** * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid @@ -355,33 +276,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract long nullsNativeAddress(); public abstract long valuesNativeAddress(); - /** - * Sets the value at rowId to null/not null. - */ - public abstract void putNotNull(int rowId); - public abstract void putNull(int rowId); - - /** - * Sets the values from [rowId, rowId + count) to null/not null. - */ - public abstract void putNulls(int rowId, int count); - public abstract void putNotNulls(int rowId, int count); - /** * Returns whether the value at rowId is NULL. */ public abstract boolean isNullAt(int rowId); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putBoolean(int rowId, boolean value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putBooleans(int rowId, int count, boolean value); - /** * Returns the value for rowId. */ @@ -392,21 +291,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract boolean[] getBooleans(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putByte(int rowId, byte value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putBytes(int rowId, int count, byte value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -417,21 +301,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract byte[] getBytes(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putShort(int rowId, short value); - - /** - * Sets values from [rowId, rowId + count) to value. 
- */ - public abstract void putShorts(int rowId, int count, short value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -442,27 +311,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract short[] getShorts(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putInt(int rowId, int value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putInts(int rowId, int count, int value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putInts(int rowId, int count, int[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be 4-byte little endian ints. - */ - public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -480,27 +328,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract int getDictId(int rowId); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putLong(int rowId, long value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putLongs(int rowId, int count, long value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be 8-byte little endian longs. - */ - public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -511,27 +338,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract long[] getLongs(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putFloat(int rowId, float value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putFloats(int rowId, int count, float value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. - */ - public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -542,27 +348,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract float[] getFloats(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putDouble(int rowId, double value); - - /** - * Sets values from [rowId, rowId + count) to value. 
- */ - public abstract void putDoubles(int rowId, int count, double value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. - */ - public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -573,11 +358,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract double[] getDoubles(int rowId, int count); - /** - * Puts a byte array that already exists in this column. - */ - public abstract void putArray(int rowId, int offset, int length); - /** * Returns the length of the array at rowid. */ @@ -608,7 +388,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { /** * Returns the array at rowid. */ - public final Array getArray(int rowId) { + public final ColumnVector.Array getArray(int rowId) { resultArray.length = getArrayLength(rowId); resultArray.offset = getArrayOffset(rowId); return resultArray; @@ -617,24 +397,7 @@ public final Array getArray(int rowId) { /** * Loads the data into array.byteArray. */ - public abstract void loadBytes(Array array); - - /** - * Sets the value at rowId to `value`. - */ - public abstract int putByteArray(int rowId, byte[] value, int offset, int count); - public final int putByteArray(int rowId, byte[] value) { - return putByteArray(rowId, value, 0, value.length); - } - - /** - * Returns the value for rowId. - */ - private Array getByteArray(int rowId) { - Array array = getArray(rowId); - array.data.loadBytes(array); - return array; - } + public abstract void loadBytes(ColumnVector.Array array); /** * Returns the value for rowId. @@ -646,354 +409,42 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. */ - public Decimal getDecimal(int rowId, int precision, int scale) { - if (precision <= Decimal.MAX_INT_DIGITS()) { - return Decimal.createUnsafe(getInt(rowId), precision, scale); - } else if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.createUnsafe(getLong(rowId), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(rowId); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } - } - - - public void putDecimal(int rowId, Decimal value, int precision) { - if (precision <= Decimal.MAX_INT_DIGITS()) { - putInt(rowId, (int) value.toUnscaledLong()); - } else if (precision <= Decimal.MAX_LONG_DIGITS()) { - putLong(rowId, value.toUnscaledLong()); - } else { - BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); - putByteArray(rowId, bigInteger.toByteArray()); - } - } + public abstract Decimal getDecimal(int rowId, int precision, int scale); /** * Returns the UTF8String for rowId. */ - public UTF8String getUTF8String(int rowId) { - if (dictionary == null) { - ColumnVector.Array a = getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); - } else { - byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - return UTF8String.fromBytes(bytes); - } - } + public abstract UTF8String getUTF8String(int rowId); /** * Returns the byte array for rowId. 
*/ - public byte[] getBinary(int rowId) { - if (dictionary == null) { - ColumnVector.Array array = getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; - } else { - return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - } - } - - /** - * Append APIs. These APIs all behave similarly and will append data to the current vector. It - * is not valid to mix the put and append APIs. The append APIs are slower and should only be - * used if the sizes are not known up front. - * In all these cases, the return value is the rowId for the first appended element. - */ - public final int appendNull() { - assert (!(dataType() instanceof StructType)); // Use appendStruct() - reserve(elementsAppended + 1); - putNull(elementsAppended); - return elementsAppended++; - } - - public final int appendNotNull() { - reserve(elementsAppended + 1); - putNotNull(elementsAppended); - return elementsAppended++; - } - - public final int appendNulls(int count) { - assert (!(dataType() instanceof StructType)); - reserve(elementsAppended + count); - int result = elementsAppended; - putNulls(elementsAppended, count); - elementsAppended += count; - return result; - } - - public final int appendNotNulls(int count) { - assert (!(dataType() instanceof StructType)); - reserve(elementsAppended + count); - int result = elementsAppended; - putNotNulls(elementsAppended, count); - elementsAppended += count; - return result; - } - - public final int appendBoolean(boolean v) { - reserve(elementsAppended + 1); - putBoolean(elementsAppended, v); - return elementsAppended++; - } - - public final int appendBooleans(int count, boolean v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putBooleans(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendByte(byte v) { - reserve(elementsAppended + 1); - putByte(elementsAppended, v); - return elementsAppended++; - } - - public final int appendBytes(int count, byte v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putBytes(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendBytes(int length, byte[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putBytes(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendShort(short v) { - reserve(elementsAppended + 1); - putShort(elementsAppended, v); - return elementsAppended++; - } - - public final int appendShorts(int count, short v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putShorts(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendShorts(int length, short[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putShorts(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendInt(int v) { - reserve(elementsAppended + 1); - putInt(elementsAppended, v); - return elementsAppended++; - } - - public final int appendInts(int count, int v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putInts(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendInts(int length, int[] src, int offset) { - reserve(elementsAppended + 
length); - int result = elementsAppended; - putInts(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendLong(long v) { - reserve(elementsAppended + 1); - putLong(elementsAppended, v); - return elementsAppended++; - } - - public final int appendLongs(int count, long v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putLongs(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendLongs(int length, long[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putLongs(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendFloat(float v) { - reserve(elementsAppended + 1); - putFloat(elementsAppended, v); - return elementsAppended++; - } - - public final int appendFloats(int count, float v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putFloats(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendFloats(int length, float[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putFloats(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendDouble(double v) { - reserve(elementsAppended + 1); - putDouble(elementsAppended, v); - return elementsAppended++; - } - - public final int appendDoubles(int count, double v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putDoubles(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendDoubles(int length, double[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putDoubles(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendByteArray(byte[] value, int offset, int length) { - int copiedOffset = arrayData().appendBytes(length, value, offset); - reserve(elementsAppended + 1); - putArray(elementsAppended, copiedOffset, length); - return elementsAppended++; - } - - public final int appendArray(int length) { - reserve(elementsAppended + 1); - putArray(elementsAppended, arrayData().elementsAppended, length); - return elementsAppended++; - } - - /** - * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this - * recursively appends a NULL to its children. - * We don't have this logic as the general appendNull implementation to optimize the more - * common non-struct case. - */ - public final int appendStruct(boolean isNull) { - if (isNull) { - appendNull(); - for (ColumnVector c: childColumns) { - if (c.type instanceof StructType) { - c.appendStruct(true); - } else { - c.appendNull(); - } - } - } else { - appendNotNull(); - } - return elementsAppended; - } + public abstract byte[] getBinary(int rowId); /** * Returns the data for the underlying array. */ - public final ColumnVector arrayData() { return childColumns[0]; } + public abstract ColumnVector arrayData(); /** * Returns the ordinal's child data column. */ - public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } - - /** - * Returns the elements appended. - */ - public final int getElementsAppended() { return elementsAppended; } + public abstract ColumnVector getChildColumn(int ordinal); /** * Returns true if this column is an array. 
*/ public final boolean isArray() { return resultArray != null; } - /** - * Marks this column as being constant. - */ - public final void setIsConstant() { isConstant = true; } - - /** - * Maximum number of rows that can be stored in this column. - */ - protected int capacity; - - /** - * Upper limit for the maximum capacity for this column. - */ - @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE; - /** * Data type for this column. */ protected DataType type; - /** - * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. - */ - protected int numNulls; - - /** - * True if there is at least one NULL byte set. This is an optimization for the writer, to skip - * having to clear NULL bits. - */ - protected boolean anyNullsSet; - - /** - * True if this column's values are fixed. This means the column values never change, even - * across resets. - */ - protected boolean isConstant; - - /** - * Default size of each array length value. This grows as necessary. - */ - protected static final int DEFAULT_ARRAY_LENGTH = 4; - - /** - * Current write cursor (row index) when appending data. - */ - protected int elementsAppended; - - /** - * If this is a nested type (array or struct), the column for the child data. - */ - protected ColumnVector[] childColumns; - /** * Reusable Array holder for getArray(). */ - protected Array resultArray; + protected ColumnVector.Array resultArray; /** * Reusable Struct holder for getStruct(). @@ -1012,32 +463,11 @@ public final int appendStruct(boolean isNull) { */ protected ColumnVector dictionaryIds; - /** - * Update the dictionary. - */ - public void setDictionary(Dictionary dictionary) { - this.dictionary = dictionary; - } - /** * Returns true if this column has a dictionary. */ public boolean hasDictionary() { return this.dictionary != null; } - /** - * Reserve a integer column for ids of dictionary. - */ - public ColumnVector reserveDictionaryIds(int capacity) { - if (dictionaryIds == null) { - dictionaryIds = allocate(capacity, DataTypes.IntegerType, - this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(capacity); - } - return dictionaryIds; - } - /** * Returns the underlying integer column for ids of dictionary. */ @@ -1049,43 +479,7 @@ public ColumnVector getDictionaryIds() { * Sets up the common state and also handles creating the child columns if this is a nested * type. 
*/ - protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { - this.capacity = capacity; + protected ColumnVector(DataType type) { this.type = type; - - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType - || DecimalType.isByteArrayDecimalType(type)) { - DataType childType; - int childCapacity = capacity; - if (type instanceof ArrayType) { - childType = ((ArrayType)type).elementType(); - } else { - childType = DataTypes.ByteType; - childCapacity *= DEFAULT_ARRAY_LENGTH; - } - this.childColumns = new ColumnVector[1]; - this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); - this.resultArray = new Array(this.childColumns[0]); - this.resultStruct = null; - } else if (type instanceof StructType) { - StructType st = (StructType)type; - this.childColumns = new ColumnVector[st.fields().length]; - for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); - } - this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); - } else if (type instanceof CalendarIntervalType) { - // Two columns. Months as int. Microseconds as Long. - this.childColumns = new ColumnVector[2]; - this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); - this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); - this.resultArray = null; - this.resultStruct = new ColumnarBatch.Row(this.childColumns); - } else { - this.childColumns = null; - this.resultArray = null; - this.resultStruct = null; - } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 900d7c431e723..adb859ed17757 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -40,7 +40,7 @@ public class ColumnVectorUtils { /** * Populates the entire `col` with `row[fieldIdx]` */ - public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { + public static void populate(WritableColumnVector col, InternalRow row, int fieldIdx) { int capacity = col.capacity; DataType t = col.dataType(); @@ -115,7 +115,7 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) { } } - private static void appendValue(ColumnVector dst, DataType t, Object o) { + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { dst.appendStruct(true); @@ -165,7 +165,7 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) { } } - private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) { + private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) { if (t instanceof ArrayType) { ArrayType at = (ArrayType)t; if (src.isNullAt(fieldIdx)) { @@ -198,15 +198,23 @@ private static void appendValue(ColumnVector dst, DataType t, Row src, int field */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode); + int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + WritableColumnVector[] columnVectors; + if (memMode == MemoryMode.OFF_HEAP) { + columnVectors = 
OffHeapColumnVector.allocateColumns(capacity, schema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema); + } + int n = 0; while (row.hasNext()) { Row r = row.next(); for (int i = 0; i < schema.fields().length; i++) { - appendValue(batch.column(i), schema.fields()[i].dataType(), r, i); + appendValue(columnVectors[i], schema.fields()[i].dataType(), r, i); } n++; } + ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 34dc3af9b85c8..e782756a3e781 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -19,7 +19,6 @@ import java.math.BigDecimal; import java.util.*; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -44,8 +43,7 @@ * - Compaction: The batch and columns should be able to compact based on a selection vector. */ public final class ColumnarBatch { - private static final int DEFAULT_BATCH_SIZE = 4 * 1024; - private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + public static final int DEFAULT_BATCH_SIZE = 4 * 1024; private final StructType schema; private final int capacity; @@ -64,18 +62,6 @@ public final class ColumnarBatch { // Staging row returned from getRow. final Row row; - public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { - return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); - } - - public static ColumnarBatch allocate(StructType type) { - return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE); - } - - public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) { - return new ColumnarBatch(schema, maxRows, memMode); - } - /** * Called to close all the columns in this batch. It is not valid to access the data after * calling this. This must be called at the end to clean up memory allocations. @@ -95,12 +81,19 @@ public static final class Row extends InternalRow { private final ColumnarBatch parent; private final int fixedLenRowSize; private final ColumnVector[] columns; + private final WritableColumnVector[] writableColumns; // Ctor used if this is a top level row. private Row(ColumnarBatch parent) { this.parent = parent; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); this.columns = parent.columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } } // Ctor used if this is a struct. 
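For reference, a minimal sketch (not part of this patch) of how callers build a batch once the static ColumnarBatch.allocate factories are removed: the writable vectors are allocated explicitly and handed to the new public constructor, mirroring ColumnVectorUtils.toBatch above. The schema, column names, and the useOffHeap flag below are illustrative values only.

    import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
    import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    class BatchAllocationSketch {
      static ColumnarBatch newBatch(boolean useOffHeap) {
        StructType schema = new StructType()
            .add("id", DataTypes.LongType)
            .add("name", DataTypes.StringType);
        int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;

        // Allocation is now explicit: the caller decides between on-heap and off-heap vectors.
        WritableColumnVector[] vectors;
        if (useOffHeap) {
          vectors = OffHeapColumnVector.allocateColumns(capacity, schema);
        } else {
          vectors = OnHeapColumnVector.allocateColumns(capacity, schema);
        }

        // ColumnarBatch no longer allocates its own columns; it only wraps the ones it is given.
        return new ColumnarBatch(schema, vectors, capacity);
      }
    }
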
@@ -108,6 +101,12 @@ protected Row(ColumnVector[] columns) { this.parent = null; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); this.columns = columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } } /** @@ -307,64 +306,69 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNull(rowId); + getWritableColumn(ordinal).putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putBoolean(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putByte(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putShort(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putInt(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putLong(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putFloat(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDouble(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDecimal(rowId, value, precision); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDecimal(rowId, value, precision); + } + + private WritableColumnVector getWritableColumn(int ordinal) { + WritableColumnVector column = writableColumns[ordinal]; + assert (!column.isConstant); + return column; } } @@ -409,7 +413,9 @@ public void remove() { */ public void reset() { for (int i = 0; i < numCols(); ++i) { - columns[i].reset(); + if (columns[i] instanceof WritableColumnVector) { + ((WritableColumnVector) 
columns[i]).reset(); + } } if (this.numRowsFiltered > 0) { Arrays.fill(filteredRows, false); @@ -427,7 +433,7 @@ public void setNumRows(int numRows) { this.numRows = numRows; for (int ordinal : nullFilteredColumns) { - if (columns[ordinal].numNulls != 0) { + if (columns[ordinal].numNulls() != 0) { for (int rowId = 0; rowId < numRows; rowId++) { if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) { filteredRows[rowId] = true; @@ -505,18 +511,12 @@ public void filterNullsInColumn(int ordinal) { nullFilteredColumns.add(ordinal); } - private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { + public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; - this.capacity = maxRows; - this.columns = new ColumnVector[schema.size()]; + this.columns = columns; + this.capacity = capacity; this.nullFilteredColumns = new HashSet<>(); - this.filteredRows = new boolean[maxRows]; - - for (int i = 0; i < schema.fields().length; ++i) { - StructField field = schema.fields()[i]; - columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); - } - + this.filteredRows = new boolean[capacity]; this.row = new Row(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 2d1f3da8e7463..35682756ed6c3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,18 +19,39 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. */ -public final class OffHeapColumnVector extends ColumnVector { +public final class OffHeapColumnVector extends WritableColumnVector { private static final boolean bigEndianPlatform = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + /** + * Allocates columns to store elements of each field of the schema off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OffHeapColumnVector[] allocateColumns(int capacity, StructType schema) { + return allocateColumns(capacity, schema.fields()); + } + + /** + * Allocates columns to store elements of each field off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { + OffHeapColumnVector[] vectors = new OffHeapColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + vectors[i] = new OffHeapColumnVector(capacity, fields[i].dataType()); + } + return vectors; + } + // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. 
private long nulls; @@ -40,8 +61,8 @@ public final class OffHeapColumnVector extends ColumnVector { private long lengthData; private long offsetData; - protected OffHeapColumnVector(int capacity, DataType type) { - super(capacity, type, MemoryMode.OFF_HEAP); + public OffHeapColumnVector(int capacity, DataType type) { + super(capacity, type); nulls = 0; data = 0; @@ -519,4 +540,9 @@ protected void reserveInternal(int newCapacity) { Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } + + @Override + protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OffHeapColumnVector(capacity, type); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 506434364be48..96a452978cb35 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -20,7 +20,6 @@ import java.nio.ByteOrder; import java.util.Arrays; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -28,11 +27,33 @@ * A column backed by an in memory JVM array. This stores the NULLs as a byte per value * and a java array for the values. */ -public final class OnHeapColumnVector extends ColumnVector { +public final class OnHeapColumnVector extends WritableColumnVector { private static final boolean bigEndianPlatform = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + /** + * Allocates columns to store elements of each field of the schema on heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OnHeapColumnVector[] allocateColumns(int capacity, StructType schema) { + return allocateColumns(capacity, schema.fields()); + } + + /** + * Allocates columns to store elements of each field on heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { + OnHeapColumnVector[] vectors = new OnHeapColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + vectors[i] = new OnHeapColumnVector(capacity, fields[i].dataType()); + } + return vectors; + } + // The data stored in these arrays need to maintain binary compatible. We can // directly pass this buffer to external components. 
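As a second sketch (again not part of the patch), the now-public OnHeapColumnVector constructor and the put/get APIs can be exercised directly; per the WritableColumnVector contract, the writer is responsible for calling reserve() before writing past the current capacity. The initial capacity and the values written are arbitrary.

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    import org.apache.spark.sql.types.DataTypes;

    class OnHeapVectorSketch {
      static long sumOfSquares(int n) {
        OnHeapColumnVector col = new OnHeapColumnVector(4, DataTypes.IntegerType);
        col.reserve(n);                 // grows the backing arrays so rows [0, n) are writable
        for (int i = 0; i < n; i++) {
          col.putInt(i, i * i);         // dense positional writes
        }
        long sum = 0;
        for (int i = 0; i < n; i++) {
          sum += col.getInt(i);         // read back through the ColumnVector read API
        }
        col.close();                    // release the vector's storage
        return sum;
      }
    }
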
@@ -51,8 +72,9 @@ public final class OnHeapColumnVector extends ColumnVector { private int[] arrayLengths; private int[] arrayOffsets; - protected OnHeapColumnVector(int capacity, DataType type) { - super(capacity, type, MemoryMode.ON_HEAP); + public OnHeapColumnVector(int capacity, DataType type) { + super(capacity, type); + reserveInternal(capacity); reset(); } @@ -529,4 +551,9 @@ protected void reserveInternal(int newCapacity) { capacity = newCapacity; } + + @Override + protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OnHeapColumnVector(capacity, type); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java deleted file mode 100644 index e9f6e7c631fd4..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.types.*; - -/** - * An abstract class for read-only column vector. 
- */ -public abstract class ReadOnlyColumnVector extends ColumnVector { - - protected ReadOnlyColumnVector(int capacity, DataType type, MemoryMode memMode) { - super(capacity, DataTypes.NullType, memMode); - this.type = type; - isConstant = true; - } - - // - // APIs dealing with nulls - // - - @Override - public final void putNotNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNotNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Booleans - // - - @Override - public final void putBoolean(int rowId, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBooleans(int rowId, int count, boolean value) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Bytes - // - - @Override - public final void putByte(int rowId, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBytes(int rowId, int count, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Shorts - // - - @Override - public final void putShort(int rowId, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putShorts(int rowId, int count, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putShorts(int rowId, int count, short[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Ints - // - - @Override - public final void putInt(int rowId, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putInts(int rowId, int count, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putInts(int rowId, int count, int[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Longs - // - - @Override - public final void putLong(int rowId, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongs(int rowId, int count, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongs(int rowId, int count, long[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with floats - // - - @Override - public final void putFloat(int rowId, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, float[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, 
byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with doubles - // - - @Override - public final void putDouble(int rowId, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Arrays - // - - @Override - public final void putArray(int rowId, int offset, int length) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Byte Arrays - // - - @Override - public final int putByteArray(int rowId, byte[] value, int offset, int count) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Decimals - // - - @Override - public final void putDecimal(int rowId, Decimal value, int precision) { - throw new UnsupportedOperationException(); - } - - // - // Other APIs - // - - @Override - public final void setDictionary(Dictionary dictionary) { - throw new UnsupportedOperationException(); - } - - @Override - public final ColumnVector reserveDictionaryIds(int capacity) { - throw new UnsupportedOperationException(); - } - - @Override - protected final void reserveInternal(int newCapacity) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java new file mode 100644 index 0000000000000..b4f753c0bc2a3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class adds write APIs to ColumnVector. + * It supports all the types and contains put APIs as well as their batched versions. + * The batched versions are preferable whenever possible. + * + * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the + * responsibility of the caller to call reserve() to ensure there is enough room before adding + * elements. 
This means that the put() APIs do not check as in common cases (i.e. flat schemas), + * the lengths are known up front. + * + * A ColumnVector should be considered immutable once originally created. In other words, it is not + * valid to call put APIs after reads until reset() is called. + */ +public abstract class WritableColumnVector extends ColumnVector { + + /** + * Resets this column for writing. The currently stored values are no longer accessible. + */ + public void reset() { + if (isConstant) return; + + if (childColumns != null) { + for (ColumnVector c: childColumns) { + ((WritableColumnVector) c).reset(); + } + } + numNulls = 0; + elementsAppended = 0; + if (anyNullsSet) { + putNotNulls(0, capacity); + anyNullsSet = false; + } + } + + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) { + int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); + if (requiredCapacity <= newCapacity) { + try { + reserveInternal(newCapacity); + } catch (OutOfMemoryError outOfMemoryError) { + throwUnsupportedException(requiredCapacity, outOfMemoryError); + } + } else { + throwUnsupportedException(requiredCapacity, null); + } + } + } + + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; + throw new RuntimeException(message, cause); + } + + @Override + public int numNulls() { return numNulls; } + + @Override + public boolean anyNullsSet() { return anyNullsSet; } + + /** + * Ensures that there is enough storage to store capacity elements. That is, the put() APIs + * must work for all rowIds < capacity. + */ + protected abstract void reserveInternal(int capacity); + + /** + * Sets the value at rowId to null/not null. + */ + public abstract void putNotNull(int rowId); + public abstract void putNull(int rowId); + + /** + * Sets the values from [rowId, rowId + count) to null/not null. + */ + public abstract void putNulls(int rowId, int count); + public abstract void putNotNulls(int rowId, int count); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putByte(int rowId, byte value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBytes(int rowId, int count, byte value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putInt(int rowId, int value); + + /** + * Sets values from [rowId, rowId + count) to value. 
+ */ + public abstract void putInts(int rowId, int count, int value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putInts(int rowId, int count, int[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 4-byte little endian ints. + */ + public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putLong(int rowId, long value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putLongs(int rowId, int count, long value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 8-byte little endian longs. + */ + public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putDouble(int rowId, double value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putDoubles(int rowId, int count, double value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted doubles. + */ + public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); + + /** + * Puts a byte array that already exists in this column. + */ + public abstract void putArray(int rowId, int offset, int length); + + /** + * Sets the value at rowId to `value`. + */ + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public final int putByteArray(int rowId, byte[] value) { + return putByteArray(rowId, value, 0, value.length); + } + + /** + * Returns the value for rowId. + */ + private ColumnVector.Array getByteArray(int rowId) { + ColumnVector.Array array = getArray(rowId); + array.data.loadBytes(array); + return array; + } + + /** + * Returns the decimal for rowId. + */ + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.createUnsafe(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.createUnsafe(getLong(rowId), precision, scale); + } else { + // TODO: best perf? 
+ byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + public void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + putInt(rowId, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + putLong(rowId, value.toUnscaledLong()); + } else { + BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); + putByteArray(rowId, bigInteger.toByteArray()); + } + } + + /** + * Returns the UTF8String for rowId. + */ + @Override + public UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); + return UTF8String.fromBytes(bytes); + } + } + + /** + * Returns the byte array for rowId. + */ + @Override + public byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); + } + } + + /** + * Append APIs. These APIs all behave similarly and will append data to the current vector. It + * is not valid to mix the put and append APIs. The append APIs are slower and should only be + * used if the sizes are not known up front. + * In all these cases, the return value is the rowId for the first appended element. + */ + public final int appendNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() + reserve(elementsAppended + 1); + putNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNotNull() { + reserve(elementsAppended + 1); + putNotNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendNotNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNotNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendByte(byte v) { + reserve(elementsAppended + 1); + putByte(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBytes(int count, byte v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBytes(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendBytes(int length, byte[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } 
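+
  // A minimal usage sketch (editorial illustration, not part of this file): the append
  // APIs above grow the vector as needed and return the row id of the first element they
  // appended. "exampleAppend" is a hypothetical helper shown only for illustration; it
  // must not be mixed with the positional put APIs on the same vector.
  static int exampleAppend(WritableColumnVector col) {
    int first = col.appendInts(3, 7);             // appends three 7s starting at row `first`
    col.appendNull();                             // appends one null (non-struct columns only)
    assert col.getElementsAppended() == first + 4;
    return first;
  }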
+ + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendInt(int v) { + reserve(elementsAppended + 1); + putInt(elementsAppended, v); + return elementsAppended++; + } + + public final int appendInts(int count, int v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putInts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendInts(int length, int[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putInts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendLong(long v) { + reserve(elementsAppended + 1); + putLong(elementsAppended, v); + return elementsAppended++; + } + + public final int appendLongs(int count, long v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putLongs(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendLongs(int length, long[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putLongs(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendFloats(int length, float[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putFloats(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendDouble(double v) { + reserve(elementsAppended + 1); + putDouble(elementsAppended, v); + return elementsAppended++; + } + + public final int appendDoubles(int count, double v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putDoubles(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDoubles(int length, double[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putDoubles(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendByteArray(byte[] value, int offset, int length) { + int copiedOffset = arrayData().appendBytes(length, value, offset); + reserve(elementsAppended + 1); + putArray(elementsAppended, copiedOffset, length); + return elementsAppended++; + } + + public final int appendArray(int length) { + reserve(elementsAppended + 1); + putArray(elementsAppended, arrayData().elementsAppended, length); + return elementsAppended++; + } + + /** + * Appends a NULL struct. 
This *has* to be used for structs instead of appendNull() as this + * recursively appends a NULL to its children. + * We don't have this logic as the general appendNull implementation to optimize the more + * common non-struct case. + */ + public final int appendStruct(boolean isNull) { + if (isNull) { + appendNull(); + for (ColumnVector c: childColumns) { + if (c.type instanceof StructType) { + ((WritableColumnVector) c).appendStruct(true); + } else { + ((WritableColumnVector) c).appendNull(); + } + } + } else { + appendNotNull(); + } + return elementsAppended; + } + + /** + * Returns the data for the underlying array. + */ + @Override + public WritableColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + @Override + public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + /** + * Returns the elements appended. + */ + public final int getElementsAppended() { return elementsAppended; } + + /** + * Marks this column as being constant. + */ + public final void setIsConstant() { isConstant = true; } + + /** + * Maximum number of rows that can be stored in this column. + */ + protected int capacity; + + /** + * Upper limit for the maximum capacity for this column. + */ + @VisibleForTesting + protected int MAX_CAPACITY = Integer.MAX_VALUE; + + /** + * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. + */ + protected int numNulls; + + /** + * True if there is at least one NULL byte set. This is an optimization for the writer, to skip + * having to clear NULL bits. + */ + protected boolean anyNullsSet; + + /** + * True if this column's values are fixed. This means the column values never change, even + * across resets. + */ + protected boolean isConstant; + + /** + * Default size of each array length value. This grows as necessary. + */ + protected static final int DEFAULT_ARRAY_LENGTH = 4; + + /** + * Current write cursor (row index) when appending data. + */ + protected int elementsAppended; + + /** + * If this is a nested type (array or struct), the column for the child data. + */ + protected WritableColumnVector[] childColumns; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve a integer column for ids of dictionary. + */ + public WritableColumnVector reserveDictionaryIds(int capacity) { + WritableColumnVector dictionaryIds = (WritableColumnVector) this.dictionaryIds; + if (dictionaryIds == null) { + dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); + this.dictionaryIds = dictionaryIds; + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + + /** + * Returns the underlying integer column for ids of dictionary. + */ + @Override + public WritableColumnVector getDictionaryIds() { + return (WritableColumnVector) dictionaryIds; + } + + /** + * Reserve a new column. + */ + protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); + + /** + * Sets up the common state and also handles creating the child columns if this is a nested + * type. 
+ */ + protected WritableColumnVector(int capacity, DataType type) { + super(type); + this.capacity = capacity; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { + DataType childType; + int childCapacity = capacity; + if (type instanceof ArrayType) { + childType = ((ArrayType)type).elementType(); + } else { + childType = DataTypes.ByteType; + childCapacity *= DEFAULT_ARRAY_LENGTH; + } + this.childColumns = new WritableColumnVector[1]; + this.childColumns[0] = reserveNewColumn(childCapacity, childType); + this.resultArray = new ColumnVector.Array(this.childColumns[0]); + this.resultStruct = null; + } else if (type instanceof StructType) { + StructType st = (StructType)type; + this.childColumns = new WritableColumnVector[st.fields().length]; + for (int i = 0; i < childColumns.length; ++i) { + this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); + } + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); + this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0c40417db0837..13f79275cac41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -76,6 +76,8 @@ class VectorizedHashMapGenerator( }.mkString("\n").concat(";") s""" + | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors; + | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors; | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; | private int[] buckets; @@ -89,14 +91,19 @@ class VectorizedHashMapGenerator( | $generatedAggBufferSchema | | public $generatedClassName() { - | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, - | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); - | // TODO: Possibly generate this projection in HashAggregate directly - | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( - | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); - | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { - | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); + | batchVectors = org.apache.spark.sql.execution.vectorized + | .OnHeapColumnVector.allocateColumns(capacity, schema); + | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( + | schema, batchVectors, capacity); + | + | bufferVectors = new org.apache.spark.sql.execution.vectorized + | .OnHeapColumnVector[aggregateBufferSchema.fields().length]; + | for (int i = 0; i < 
aggregateBufferSchema.fields().length; i++) { + | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}]; | } + | // TODO: Possibly generate this projection in HashAggregate directly + | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( + | aggregateBufferSchema, bufferVectors, capacity); | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); @@ -112,8 +119,8 @@ class VectorizedHashMapGenerator( * * {{{ * private boolean equals(int idx, long agg_key, long agg_key1) { - * return batch.column(0).getLong(buckets[idx]) == agg_key && - * batch.column(1).getLong(buckets[idx]) == agg_key1; + * return batchVectors[0].getLong(buckets[idx]) == agg_key && + * batchVectors[1].getLong(buckets[idx]) == agg_key1; * } * }}} */ @@ -121,8 +128,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("batch", "buckets[idx]", - key.dataType, ordinal), key.name)})""" + s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]", + key.dataType), key.name)})""" }.mkString(" && ") } @@ -150,9 +157,9 @@ class VectorizedHashMapGenerator( * while (step < maxSteps) { * // Return bucket index if it's either an empty slot or already contains the key * if (buckets[idx] == -1) { - * batch.column(0).putLong(numRows, agg_key); - * batch.column(1).putLong(numRows, agg_key1); - * batch.column(2).putLong(numRows, 0); + * batchVectors[0].putLong(numRows, agg_key); + * batchVectors[1].putLong(numRows, agg_key1); + * batchVectors[2].putLong(numRows, 0); * buckets[idx] = numRows++; * return batch.getRow(buckets[idx]); * } else if (equals(idx, agg_key, agg_key1)) { @@ -170,13 +177,13 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue("batch", "numRows", key.dataType, ordinal, key.name) + ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn("batch", "numRows", key.dataType, groupingKeys.length + ordinal, + ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 67b3d98c1daed..1331f157363b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -24,7 +24,10 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector -import org.apache.spark.sql.types.{BinaryType, IntegerType} +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} import org.apache.spark.unsafe.Platform import 
org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -34,6 +37,14 @@ import org.apache.spark.util.collection.BitSet */ object ColumnarBatchBenchmark { + def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -140,7 +151,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -159,7 +170,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -178,7 +189,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -244,7 +255,7 @@ object ColumnarBatchBenchmark { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -362,7 +373,7 @@ object ColumnarBatchBenchmark { .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => - val column = ColumnVector.allocate(count, BinaryType, memoryMode) + val column = allocate(count, BinaryType, memoryMode) var sum = 0L for (n <- 0L until iters) { var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index c8461dcb9dfdb..08ccbd628cf8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -34,11 +34,20 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { + + def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + test("Null Apis") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[Boolean] - val column = ColumnVector.allocate(1024, IntegerType, memMode) + val column = allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) assert(column.numNulls() == 0) @@ -109,7 +118,7 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: 
MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[Byte] - val column = ColumnVector.allocate(1024, ByteType, memMode) + val column = allocate(1024, ByteType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray column.appendBytes(2, values, 0) @@ -167,7 +176,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = ColumnVector.allocate(1024, ShortType, memMode) + val column = allocate(1024, ShortType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray column.appendShorts(2, values, 0) @@ -247,7 +256,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = ColumnVector.allocate(1024, IntegerType, memMode) + val column = allocate(1024, IntegerType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray column.appendInts(2, values, 0) @@ -332,7 +341,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = ColumnVector.allocate(1024, LongType, memMode) + val column = allocate(1024, LongType, memMode) var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray column.appendLongs(2, values, 0) @@ -419,7 +428,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] - val column = ColumnVector.allocate(1024, FloatType, memMode) + val column = allocate(1024, FloatType, memMode) var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray column.appendFloats(2, values, 0) @@ -510,7 +519,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = ColumnVector.allocate(1024, DoubleType, memMode) + val column = allocate(1024, DoubleType, memMode) var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray column.appendDoubles(2, values, 0) @@ -599,7 +608,7 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[String] - val column = ColumnVector.allocate(6, BinaryType, memMode) + val column = allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) val str = "string" @@ -656,7 +665,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Int Array") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + val column = allocate(10, new ArrayType(IntegerType, true), memMode) // Fill the underlying data with all the arrays back to back. 
val data = column.arrayData(); @@ -714,43 +723,43 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { val len = 4 - val columnBool = ColumnVector.allocate(len, new ArrayType(BooleanType, false), memMode) + val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) val boolArray = Array(false, true, false, true) boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } columnBool.putArray(0, 0, len) assert(columnBool.getArray(0).toBooleanArray === boolArray) - val columnByte = ColumnVector.allocate(len, new ArrayType(ByteType, false), memMode) + val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) val byteArray = Array[Byte](0, 1, 2, 3) byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } columnByte.putArray(0, 0, len) assert(columnByte.getArray(0).toByteArray === byteArray) - val columnShort = ColumnVector.allocate(len, new ArrayType(ShortType, false), memMode) + val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) val shortArray = Array[Short](0, 1, 2, 3) shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } columnShort.putArray(0, 0, len) assert(columnShort.getArray(0).toShortArray === shortArray) - val columnInt = ColumnVector.allocate(len, new ArrayType(IntegerType, false), memMode) + val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) val intArray = Array(0, 1, 2, 3) intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } columnInt.putArray(0, 0, len) assert(columnInt.getArray(0).toIntArray === intArray) - val columnLong = ColumnVector.allocate(len, new ArrayType(LongType, false), memMode) + val columnLong = allocate(len, new ArrayType(LongType, false), memMode) val longArray = Array[Long](0, 1, 2, 3) longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } columnLong.putArray(0, 0, len) assert(columnLong.getArray(0).toLongArray === longArray) - val columnFloat = ColumnVector.allocate(len, new ArrayType(FloatType, false), memMode) + val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } columnFloat.putArray(0, 0, len) assert(columnFloat.getArray(0).toFloatArray === floatArray) - val columnDouble = ColumnVector.allocate(len, new ArrayType(DoubleType, false), memMode) + val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) val doubleArray = Array(0.0, 1.1, 2.2, 3.3) doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } columnDouble.putArray(0, 0, len) @@ -761,7 +770,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Struct Column") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = ColumnVector.allocate(1024, schema, memMode) + val column = allocate(1024, schema, memMode) val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) @@ -790,7 +799,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Nest Array in Array.") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), + val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), memMode) val 
childColumn = column.arrayData() val data = column.arrayData().arrayData() @@ -823,7 +832,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Nest Struct in Array.") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val schema = new StructType().add("int", IntegerType).add("long", LongType) - val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode) + val column = allocate(10, new ArrayType(schema, true), memMode) val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -853,7 +862,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val schema = new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true)) - val column = ColumnVector.allocate(10, schema, memMode) + val column = allocate(10, schema, memMode) val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -885,7 +894,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val schema = new StructType() .add("int", IntegerType) .add("struct", subSchema) - val column = ColumnVector.allocate(10, schema, memMode) + val column = allocate(10, schema, memMode) val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -918,7 +927,11 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val batch = ColumnarBatch.allocate(schema, memMode) + val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val columns = schema.fields.map { field => + allocate(capacity, field.dataType, memMode) + } + val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) @@ -926,10 +939,10 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] - batch.column(0).putInt(0, 1) - batch.column(1).putDouble(0, 1.1) - batch.column(2).putNull(0) - batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) + columns(0).putInt(0, 1) + columns(1).putDouble(0, 1.1) + columns(2).putNull(0) + columns(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(1) // Verify the results of the row. @@ -939,12 +952,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == true) assert(batch.rowIterator().hasNext == true) - assert(batch.column(0).getInt(0) == 1) - assert(batch.column(0).isNullAt(0) == false) - assert(batch.column(1).getDouble(0) == 1.1) - assert(batch.column(1).isNullAt(0) == false) - assert(batch.column(2).isNullAt(0) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(0).getInt(0) == 1) + assert(columns(0).isNullAt(0) == false) + assert(columns(1).getDouble(0) == 1.1) + assert(columns(1).isNullAt(0) == false) + assert(columns(2).isNullAt(0) == true) + assert(columns(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. 
val it = batch.rowIterator() @@ -955,7 +968,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) @@ -972,20 +985,20 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] - batch.column(0).putNull(0) - batch.column(1).putDouble(0, 2.2) - batch.column(2).putInt(0, 2) - batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(1, 3) - batch.column(1).putNull(1) - batch.column(2).putInt(1, 3) - batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(2, 4) - batch.column(1).putDouble(2, 4.4) - batch.column(2).putInt(2, 4) - batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) + columns(0).putNull(0) + columns(1).putDouble(0, 2.2) + columns(2).putInt(0, 2) + columns(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(1, 3) + columns(1).putNull(1) + columns(2).putInt(1, 3) + columns(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(2, 4) + columns(1).putDouble(2, 4.4) + columns(2).putInt(2, 4) + columns(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { @@ -1232,7 +1245,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("exceeding maximum capacity should throw an error") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = ColumnVector.allocate(1, ByteType, memMode) + val column = allocate(1, ByteType, memMode) column.MAX_CAPACITY = 15 column.appendBytes(5, 0.toByte) // Successfully allocate twice the requested capacity From 183d4cb71fbcbf484fc85d8621e1fe04cbbc8195 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Aug 2017 21:46:58 +0800 Subject: [PATCH 043/187] [SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery ## What changes were proposed in this pull request? With the check for structural integrity proposed in SPARK-21726, it is found that the optimization rule `PullupCorrelatedPredicates` can produce unresolved plans. For a correlated IN query looks like: SELECT t1.a FROM t1 WHERE t1.a IN (SELECT t2.c FROM t2 WHERE t1.b < t2.d); The query plan might look like: Project [a#0] +- Filter a#0 IN (list#4 [b#1]) : +- Project [c#2] : +- Filter (outer(b#1) < d#3) : +- LocalRelation , [c#2, d#3] +- LocalRelation , [a#0, b#1] After `PullupCorrelatedPredicates`, it produces query plan like: 'Project [a#0] +- 'Filter a#0 IN (list#4 [(b#1 < d#3)]) : +- Project [c#2, d#3] : +- LocalRelation , [c#2, d#3] +- LocalRelation , [a#0, b#1] Because the correlated predicate involves another attribute `d#3` in subquery, it has been pulled out and added into the `Project` on the top of the subquery. When `list` in `In` contains just one `ListQuery`, `In.checkInputDataTypes` checks if the size of `value` expressions matches the output size of subquery. In the above example, there is only `value` expression and the subquery output has two attributes `c#2, d#3`, so it fails the check and `In.resolved` returns `false`. 
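For illustration, the failing shape can be built directly with the Catalyst test DSL (a minimal sketch that mirrors the `PullupCorrelatedPredicatesSuite` added later in this patch; the relation and column names are illustrative only):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicates
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object CorrelatedInSketch {
  def main(args: Array[String]): Unit = {
    val t1 = LocalRelation('a.int, 'b.double)   // t1(a, b)
    val t2 = LocalRelation('c.int, 'd.double)   // t2(c, d)

    // SELECT t1.a FROM t1 WHERE t1.a IN (SELECT t2.c FROM t2 WHERE t1.b < t2.d)
    val correlatedSubquery = t2.where('b < 'd).select('c)
    val outerQuery = t1.where(In('a, Seq(ListQuery(correlatedSubquery)))).select('a).analyze
    println(outerQuery.resolved)  // true

    // Pulling up the correlated predicate adds d to the subquery output; before this fix,
    // In.checkInputDataTypes then saw one value expression vs. two subquery columns and
    // the rewritten plan was reported as unresolved.
    val pulledUp = PullupCorrelatedPredicates(outerQuery)
    println(pulledUp.resolved)    // false before this change, true after it
  }
}
```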
We should not let `In.checkInputDataTypes` wrongly report unresolved plans to fail the structural integrity check. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #18968 from viirya/SPARK-21759. --- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- .../sql/catalyst/expressions/predicates.scala | 65 +++++++++---------- .../sql/catalyst/expressions/subquery.scala | 13 +++- .../sql/catalyst/optimizer/subquery.scala | 10 +-- .../PullupCorrelatedPredicatesSuite.scala | 52 +++++++++++++++ .../subq-input-typecheck.sql.out | 6 +- 7 files changed, 106 insertions(+), 51 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 70a3885d21531..1e934d0aa0e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1286,8 +1286,10 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => - val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + val expr = resolveSubQuery(l, plans)((plan, exprs) => { + ListQuery(plan, exprs, exprId, plan.output) + }) In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 06d8350db9891..9ffe646b5e4ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -402,7 +402,7 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId))) + case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) if !i.resolved && flattenExpr(a).length == sub.output.length => // LHS is the value expression of IN subquery. 
val lhs = flattenExpr(a) @@ -434,7 +434,8 @@ object TypeCoercion { case _ => CreateStruct(castedLhs) } - In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + val newSub = Project(castedRhs, sub) + In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) } else { i } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7bf10f199f1c7..613d6202b0b26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,32 +138,33 @@ case class Not(child: Expression) case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + override def checkInputDataTypes(): TypeCheckResult = { - list match { - case ListQuery(sub, _, _) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != sub.output.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. - |#columns in right hand side: ${sub.output.length}. - |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${sub.output.map(_.sql).mkString(", ")}]. - """.stripMargin) - } else { - val mismatchedColumns = valExprs.zip(sub.output).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + if (mismatchOpt.isDefined) { + list match { + case ListQuery(_, _, _, childOutputs) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) } - if (mismatchedColumns.nonEmpty) { + if (valExprs.length != childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${valExprs.length}. + |#columns in right hand side: ${childOutputs.length}. + |Left side columns: + |[${valExprs.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = valExprs.zip(childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } TypeCheckResult.TypeCheckFailure( s""" |The data type of one or more elements in the left hand side of an IN subquery @@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |Left side: |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. |Right side: - |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. 
- """.stripMargin) - } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } - } - case _ => - val mismatchOpt = list.find(l => l.dataType != value.dataType) - if (mismatchOpt.isDefined) { + case _ => TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + s"${value.dataType} != ${mismatchOpt.get.dataType}") - } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") - } + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d7b493d521ddb..c6146042ef1a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -274,9 +274,15 @@ object ScalarSubquery { case class ListQuery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, - exprId: ExprId = NamedExpression.newExprId) + exprId: ExprId = NamedExpression.newExprId, + childOutputs: Seq[Attribute] = Seq.empty) extends SubqueryExpression(plan, children, exprId) with Unevaluable { - override def dataType: DataType = plan.schema.fields.head.dataType + override def dataType: DataType = if (childOutputs.length > 1) { + childOutputs.toStructType + } else { + childOutputs.head.dataType + } + override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) override def toString: String = s"list#${exprId.id} $conditionString" @@ -284,7 +290,8 @@ case class ListQuery( ListQuery( plan.canonicalized, children.map(_.canonicalized), - ExprId(0)) + ExprId(0), + childOutputs.map(_.canonicalized.asInstanceOf[Attribute])) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9dbb6b14aaac3..4386a10162767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. 
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - case In(value, Seq(ListQuery(sub, conditions, _))) => + case In(value, Seq(ListQuery(sub, conditions, _, _))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) @@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) Exists(newPlan, newCond, exprId) - case ListQuery(sub, _, exprId) => + case ListQuery(sub, _, exprId, childOutputs) => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) - ListQuery(newPlan, newCond, exprId) + ListQuery(newPlan, newCond, exprId, childOutputs) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala new file mode 100644 index 0000000000000..169b8737d808b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class PullupCorrelatedPredicatesSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("PullupCorrelatedPredicates", Once, + PullupCorrelatedPredicates) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.double) + val testRelation2 = LocalRelation('c.int, 'd.double) + + test("PullupCorrelatedPredicates should not produce unresolved plan") { + val correlatedSubquery = + testRelation2 + .where('b < 'd) + .select('c) + val outerQuery = + testRelation + .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .select('a).analyze + assert(outerQuery.resolved) + + val optimized = Optimize.execute(outerQuery) + assert(optimized.resolved) + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index 9ea9d3c4c6f40..70aeb9373f3c7 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -80,8 +80,7 @@ number of columns in the output of subquery. Left side columns: [t1.`t1a`]. Right side columns: -[t2.`t2a`, t2.`t2b`]. - ; +[t2.`t2a`, t2.`t2b`].; -- !query 6 @@ -102,5 +101,4 @@ number of columns in the output of subquery. Left side columns: [t1.`t1a`, t1.`t1b`]. Right side columns: -[t2.`t2a`]. - ; +[t2.`t2a`].; From 2dd37d827f2e443dcb3eaf8a95437d179130d55c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 24 Aug 2017 16:44:12 +0200 Subject: [PATCH 044/187] [SPARK-21826][SQL] outer broadcast hash join should not throw NPE ## What changes were proposed in this pull request? This is a bug introduced by https://github.com/apache/spark/pull/11274/files#diff-7adb688cbfa583b5711801f196a074bbL274 . Non-equal join condition should only be applied when the equal-join condition matches. ## How was this patch tested? regression test Author: Wenchen Fan Closes #19036 from cloud-fan/bug. 
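For reference, the query shape that triggered the NPE can be sketched as a small standalone program (it mirrors the regression test added below; the local session setup is illustrative and not part of the change):

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType

object OuterBroadcastJoinNpeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("SPARK-21826 sketch").getOrCreate()
    import spark.implicits._

    Seq(2 -> 2).toDF("x", "y").createTempView("v1")

    // Non-nullable columns matter: they let whole-stage codegen skip null checks, which is
    // how the pre-fix generated code could evaluate `length(j)` against a missing (null)
    // build-side row.
    spark.createDataFrame(
      Seq(Row(1, "a")).asJava,
      new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false)
    ).createTempView("v2")

    // Left outer broadcast hash join whose non-equal condition references the build side.
    // Before the fix this could throw a NullPointerException; after it, the unmatched row
    // comes back as Row(2, 2, null, null).
    spark.sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)").show()

    spark.stop()
  }
}
```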
--- .../joins/BroadcastHashJoinExec.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index bfa1e9d49a545..2f52a089ef9bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -283,8 +283,8 @@ case class BroadcastHashJoinExec( s""" |boolean $conditionPassed = true; |${eval.trim} - |${ev.code} |if ($matched != null) { + | ${ev.code} | $conditionPassed = !${ev.isNull} && ${ev.value}; |} """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 86fe09bd977af..453052a8ce191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.language.existentials @@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -767,4 +769,22 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } } + + test("outer broadcast hash join should not throw NPE") { + withTempView("v1", "v2") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + Seq(2 -> 2).toDF("x", "y").createTempView("v1") + + spark.createDataFrame( + Seq(Row(1, "a")).asJava, + new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false) + ).createTempView("v2") + + checkAnswer( + sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)"), + Row(2, 2, null, null) + ) + } + } + } } From d3abb36990d928a8445a8c69ddebeabdfeb1484d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 24 Aug 2017 10:23:59 -0700 Subject: [PATCH 045/187] [SPARK-21788][SS] Handle more exceptions when stopping a streaming query ## What changes were proposed in this pull request? Add more cases we should view as a normal query stop rather than a failure. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #18997 from zsxwing/SPARK-21788. 
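The essence of the change can be sketched as a small standalone helper (a simplified version of the `isInterruptedByStop` check added below; it omits the Guava `UncheckedExecutionException` case and the check that the query state is already `TERMINATED`):

```scala
import java.io.{InterruptedIOException, UncheckedIOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.ExecutionException

object InterruptClassificationSketch {
  // An exception counts as "the query was stopped" if it is an interruption, or if it
  // merely wraps one (e.g. surfaced through a thread pool, or through an API that cannot
  // throw checked exceptions), so we unwrap causes recursively.
  def causedByInterrupt(e: Throwable): Boolean = e match {
    case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
      true
    case wrapper @ (_: UncheckedIOException | _: ExecutionException) if wrapper.getCause != null =>
      causedByInterrupt(wrapper.getCause)
    case _ =>
      false
  }

  def main(args: Array[String]): Unit = {
    // An interrupt surfaced through a thread pool still reads as a normal stop...
    assert(causedByInterrupt(new ExecutionException("stopped", new InterruptedException)))
    // ...while an unrelated failure does not.
    assert(!causedByInterrupt(new RuntimeException("real failure")))
  }
}
```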
--- .../execution/streaming/StreamExecution.scala | 34 ++++++++++- .../spark/sql/streaming/StreamSuite.scala | 60 ++++++++++++++++++- 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 432b2d4925ae2..c224f2f9f1404 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.streaming -import java.io.{InterruptedIOException, IOException} +import java.io.{InterruptedIOException, IOException, UncheckedIOException} +import java.nio.channels.ClosedByInterruptException import java.util.UUID -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock @@ -27,6 +28,7 @@ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging @@ -335,7 +337,7 @@ class StreamExecution( // `stop()` is already called. Let `finally` finish the cleanup. } } catch { - case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED => + case e if isInterruptedByStop(e) => // interrupted by stop() updateStatusMessage("Stopped") case e: IOException if e.getMessage != null @@ -407,6 +409,32 @@ class StreamExecution( } } + private def isInterruptedByStop(e: Throwable): Boolean = { + if (state.get == TERMINATED) { + e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptedByStop(e2.getCause) + case _ => + false + } + } else { + false + } + } + /** * Populate the start offsets to start the execution at the current offsets stored in the sink * (i.e. avoid reprocessing data that we have already processed). 
This function must be called diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 012cccfdd9166..d0b2041a8644f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.streaming -import java.io.{File, InterruptedIOException, IOException} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} import scala.reflect.ClassTag import scala.util.control.ControlThrowable +import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration @@ -691,6 +693,31 @@ class StreamSuite extends StreamTest { } } } + + for (e <- Seq( + new InterruptedException, + new InterruptedIOException, + new ClosedByInterruptException, + new UncheckedIOException("test", new ClosedByInterruptException), + new ExecutionException("test", new InterruptedException), + new UncheckedExecutionException("test", new InterruptedException))) { + test(s"view ${e.getClass.getSimpleName} as a normal query stop") { + ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1) + ThrowingExceptionInCreateSource.exception = e + val query = spark + .readStream + .format(classOf[ThrowingExceptionInCreateSource].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingExceptionInCreateSource.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingExceptionInCreateSource.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } + } } abstract class FakeSource extends StreamSourceProvider { @@ -824,3 +851,32 @@ class TestStateStoreProvider extends StateStoreProvider { override def getStore(version: Long): StateStore = null } + +/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */ +class ThrowingExceptionInCreateSource extends FakeSource { + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + ThrowingExceptionInCreateSource.createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case _: InterruptedException => + throw ThrowingExceptionInCreateSource.exception + } + } +} + +object ThrowingExceptionInCreateSource { + /** + * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is + * called. + */ + @volatile var createSourceLatch: CountDownLatch = null + @volatile var exception: Exception = null +} From 763b83ee84cbb6f263218c471dd9198dd6bee411 Mon Sep 17 00:00:00 2001 From: "xu.zhang" Date: Thu, 24 Aug 2017 14:27:52 -0700 Subject: [PATCH 046/187] [SPARK-21701][CORE] Enable RPC client to use ` SO_RCVBUF` and ` SO_SNDBUF` in SparkConf. ## What changes were proposed in this pull request? 
TCP parameters like SO_RCVBUF and SO_SNDBUF can be set in SparkConf, and `org.apache.spark.network.server.TransportServe`r can use those parameters to build server by leveraging netty. But for TransportClientFactory, there is no such way to set those parameters from SparkConf. This could be inconsistent in server and client side when people set parameters in SparkConf. So this PR make RPC client to be enable to use those TCP parameters as well. ## How was this patch tested? Existing tests. Author: xu.zhang Closes #18964 from neoremind/add_client_param. --- .../spark/network/client/TransportClientFactory.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index b50e043d5c9ce..8add4e1ab021d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -210,6 +210,14 @@ private TransportClient createClient(InetSocketAddress address) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) .option(ChannelOption.ALLOCATOR, pooledAllocator); + if (conf.receiveBuf() > 0) { + bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + final AtomicReference clientRef = new AtomicReference<>(); final AtomicReference channelRef = new AtomicReference<>(); From 05af2de0fdce625041b99908adc320c576bac116 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 24 Aug 2017 16:33:55 -0700 Subject: [PATCH 047/187] [SPARK-21830][SQL] Bump ANTLR version and fix a few issues. ## What changes were proposed in this pull request? This PR bumps the ANTLR version to 4.7, and fixes a number of small parser related issues uncovered by the bump. The main reason for upgrading is that in some cases the current version of ANTLR (4.5) can exhibit exponential slowdowns if it needs to parse boolean predicates. For example the following query will take forever to parse: ```sql SELECT * FROM RANGE(1000) WHERE TRUE AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' AND NOT upper(DESCRIPTION) LIKE '%FOO%' ``` This is caused by a know bug in ANTLR (https://github.com/antlr/antlr4/issues/994), which was fixed in version 4.6. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #19042 from hvanhovell/SPARK-21830. 
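As a rough way to observe the difference, one can time the parser on the pathological predicate above (a hedged sketch, not part of the patch; the 18 repetitions match the example query):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object ParseTimingSketch {
  def main(args: Array[String]): Unit = {
    val predicate = "NOT upper(DESCRIPTION) LIKE '%FOO%'"
    val sql = "SELECT * FROM RANGE(1000) WHERE TRUE AND " +
      Seq.fill(18)(predicate).mkString(" AND ")

    val start = System.nanoTime()
    CatalystSqlParser.parsePlan(sql)  // near-instant on ANTLR 4.7; pathological on 4.5
    println(s"parsed in ${(System.nanoTime() - start) / 1e6} ms")
  }
}
```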
--- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- project/SparkBuild.scala | 1 + .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 6 +++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 ++++ .../spark/sql/catalyst/parser/ParseDriver.scala | 2 +- .../catalyst/parser/TableSchemaParserSuite.scala | 14 ++++++++------ .../sql-tests/results/show-tables.sql.out | 4 ++-- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 2 +- 10 files changed, 25 insertions(+), 14 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 01af2c75b0251..de1750777d36c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -5,7 +5,7 @@ activation-1.1.1.jar aircompressor-0.3.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.3.jar +antlr4-runtime-4.7.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 69f3a4bb60f8b..da826a7ee8b12 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -5,7 +5,7 @@ activation-1.1.1.jar aircompressor-0.3.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.3.jar +antlr4-runtime-4.7.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/pom.xml b/pom.xml index c0df3ef0fe200..8b4a6c5425a98 100644 --- a/pom.xml +++ b/pom.xml @@ -178,7 +178,7 @@ 3.5.2 1.3.9 0.9.3 - 4.5.3 + 4.7 1.1 2.52.0 2.6 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7565e14c9b9ed..18059adc864b5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -474,6 +474,7 @@ object OldDeps { object Catalyst { lazy val settings = antlr4Settings ++ Seq( + antlr4Version in Antlr4 := "4.7", antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser"), antlr4GenListener in Antlr4 := true, antlr4GenVisitor in Antlr4 := true diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 954955b6b1293..5d4363f945bf8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -64,6 +64,10 @@ singleDataType : dataType EOF ; +singleTableSchema + : colTypeList EOF + ; + statement : query #statementDefault | USE db=identifier #use @@ -974,7 +978,7 @@ CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + | '"' ( ~('"'|'\\') | ('\\' .) 
)* '"' ; BIGINT_LITERAL diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 22c5484b76638..8a45c5216781b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -89,6 +89,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitSparkDataType(ctx.dataType) } + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList))) + } + /* ******************************************************************************************** * Plan parsing * ******************************************************************************************** */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 09598ffe770c6..0d9ad218e48db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -61,7 +61,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { * definitions which will preserve the correct Hive metadata. */ override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser => - StructType(astBuilder.visitColTypeList(parser.colTypeList())) + astBuilder.visitSingleTableSchema(parser.singleTableSchema()) } /** Creates LogicalPlan for a given SQL string. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 48aaec44885d4..6803fc307f919 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -79,10 +79,12 @@ class TableSchemaParserSuite extends SparkFunSuite { } // Negative cases - assertError("") - assertError("a") - assertError("a INT b long") - assertError("a INT,, b long") - assertError("a INT, b long,,") - assertError("a INT, b long, c int,") + test("Negative cases") { + assertError("") + assertError("a") + assertError("a INT b long") + assertError("a INT,, b long") + assertError("a INT, b long,,") + assertError("a INT, b long, c int,") + } } diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index da729cd757cfc..975bb06124744 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -164,7 +164,7 @@ struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input '' expecting 'LIKE'(line 1, pos 19) +mismatched input '' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 19) == SQL == SHOW TABLE EXTENDED @@ -187,7 +187,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input 'PARTITION' expecting 'LIKE'(line 1, pos 20) +mismatched input 'PARTITION' expecting {'FROM', 'IN', 'LIKE'}(line 1, pos 20) == SQL == SHOW TABLE EXTENDED PARTITION(c='Us', d=1) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index b7f97f204b24c..1985b1dc82879 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -468,7 +468,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column .jdbc(url1, "TEST.USERDBTYPETEST", properties) }.getMessage() - assert(msg.contains("no viable alternative at input")) + assert(msg.contains("extraneous input")) } test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") { From f3676d63913e0706e071b71e1742b8d57b102fba Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 25 Aug 2017 10:22:27 +0800 Subject: [PATCH 048/187] [SPARK-21108][ML] convert LinearSVC to aggregator framework ## What changes were proposed in this pull request? convert LinearSVC to new aggregator framework ## How was this patch tested? existing unit test. Author: Yuhao Yang Closes #18315 from hhbyyh/svcAggregator. --- .../spark/ml/classification/LinearSVC.scala | 204 ++---------------- .../ml/optim/aggregator/HingeAggregator.scala | 105 +++++++++ .../ml/classification/LinearSVCSuite.scala | 7 +- .../aggregator/HingeAggregatorSuite.scala | 163 ++++++++++++++ .../aggregator/LogisticAggregatorSuite.scala | 2 - 5 files changed, 286 insertions(+), 195 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8d556deef2be8..3b0666c36d20a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -25,11 +25,11 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ -import org.apache.spark.ml.linalg.BLAS._ +import org.apache.spark.ml.optim.aggregator.HingeAggregator +import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -214,10 +214,20 @@ class LinearSVC @Since("2.2.0") ( } val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val getFeaturesStd = (j: Int) => featuresStd(j) val regParamL2 = $(regParam) val bcFeaturesStd = instances.context.broadcast(featuresStd) - val costFun = new LinearSVCCostFun(instances, $(fitIntercept), - $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth)) + val regularization = if (regParamL2 != 0.0) { + val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures + Some(new L2Regularization(regParamL2, shouldApply, + if ($(standardization)) None else Some(getFeaturesStd))) + } else { + None + } + + val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_) + val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization, + $(aggregationDepth)) def regParamL1Fun = (index: Int) => 0D val 
optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) @@ -372,189 +382,3 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { } } } - -/** - * LinearSVCCostFun implements Breeze's DiffFunction[T] for hinge loss function - */ -private class LinearSVCCostFun( - instances: RDD[Instance], - fitIntercept: Boolean, - standardization: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - regParamL2: Double, - aggregationDepth: Int) extends DiffFunction[BDV[Double]] { - - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { - val coeffs = Vectors.fromBreeze(coefficients) - val bcCoeffs = instances.context.broadcast(coeffs) - val featuresStd = bcFeaturesStd.value - val numFeatures = featuresStd.length - - val svmAggregator = { - val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance) - val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2) - - instances.treeAggregate( - new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept) - )(seqOp, combOp, aggregationDepth) - } - - val totalGradientArray = svmAggregator.gradient.toArray - // regVal is the sum of coefficients squares excluding intercept for L2 regularization. - val regVal = if (regParamL2 == 0.0) { - 0.0 - } else { - var sum = 0.0 - coeffs.foreachActive { case (index, value) => - // We do not apply regularization to the intercepts - if (index != numFeatures) { - // The following code will compute the loss of the regularization; also - // the gradient of the regularization, and add back to totalGradientArray. - sum += { - if (standardization) { - totalGradientArray(index) += regParamL2 * value - value * value - } else { - if (featuresStd(index) != 0.0) { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - val temp = value / (featuresStd(index) * featuresStd(index)) - totalGradientArray(index) += regParamL2 * temp - value * temp - } else { - 0.0 - } - } - } - } - } - 0.5 * regParamL2 * sum - } - bcCoeffs.destroy(blocking = false) - - (svmAggregator.loss + regVal, new BDV(totalGradientArray)) - } -} - -/** - * LinearSVCAggregator computes the gradient and loss for hinge loss function, as used - * in binary classification for instances in sparse or dense vector in an online fashion. - * - * Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of - * the corresponding joint dataset. - * - * This class standardizes feature values during computation using bcFeaturesStd. - * - * @param bcCoefficients The coefficients corresponding to the features. - * @param fitIntercept Whether to fit an intercept term. - * @param bcFeaturesStd The standard deviation values of the features. 
- */ -private class LinearSVCAggregator( - bcCoefficients: Broadcast[Vector], - bcFeaturesStd: Broadcast[Array[Double]], - fitIntercept: Boolean) extends Serializable { - - private val numFeatures: Int = bcFeaturesStd.value.length - private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures - private var weightSum: Double = 0.0 - private var lossSum: Double = 0.0 - @transient private lazy val coefficientsArray = bcCoefficients.value match { - case DenseVector(values) => values - case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + - s" but got type ${bcCoefficients.value.getClass}.") - } - private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept) - - /** - * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient - * of the objective function. - * - * @param instance The instance of data point to be added. - * @return This LinearSVCAggregator object. - */ - def add(instance: Instance): this.type = { - instance match { case Instance(label, weight, features) => - - if (weight == 0.0) return this - val localFeaturesStd = bcFeaturesStd.value - val localCoefficients = coefficientsArray - val localGradientSumArray = gradientSumArray - - val dotProduct = { - var sum = 0.0 - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - sum += localCoefficients(index) * value / localFeaturesStd(index) - } - } - if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1) - sum - } - // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) - // Therefore the gradient is -(2y - 1)*x - val labelScaled = 2 * label - 1.0 - val loss = if (1.0 > labelScaled * dotProduct) { - weight * (1.0 - labelScaled * dotProduct) - } else { - 0.0 - } - - if (1.0 > labelScaled * dotProduct) { - val gradientScale = -labelScaled * weight - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index) - } - } - if (fitIntercept) { - localGradientSumArray(localGradientSumArray.length - 1) += gradientScale - } - } - - lossSum += loss - weightSum += weight - this - } - } - - /** - * Merge another LinearSVCAggregator, and update the loss and gradient - * of the objective function. - * (Note that it's in place merging; as a result, `this` object will be modified.) - * - * @param other The other LinearSVCAggregator to be merged. - * @return This LinearSVCAggregator object. 
- */ - def merge(other: LinearSVCAggregator): this.type = { - - if (other.weightSum != 0.0) { - weightSum += other.weightSum - lossSum += other.lossSum - - var i = 0 - val localThisGradientSumArray = this.gradientSumArray - val localOtherGradientSumArray = other.gradientSumArray - val len = localThisGradientSumArray.length - while (i < len) { - localThisGradientSumArray(i) += localOtherGradientSumArray(i) - i += 1 - } - } - this - } - - def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0 - - def gradient: Vector = { - if (weightSum != 0) { - val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / weightSum, result) - result - } else { - Vectors.dense(new Array[Double](numFeaturesPlusIntercept)) - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala new file mode 100644 index 0000000000000..0300500a34ec0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg._ + +/** + * HingeAggregator computes the gradient and loss for Hinge loss function as used in + * binary classification for instances in sparse or dense vector in an online fashion. + * + * Two HingeAggregators can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * This class standardizes feature values during computation using bcFeaturesStd. + * + * @param bcCoefficients The coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + * @param bcFeaturesStd The standard deviation values of the features. + */ +private[ml] class HingeAggregator( + bcFeaturesStd: Broadcast[Array[Double]], + fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[Instance, HingeAggregator] { + + private val numFeatures: Int = bcFeaturesStd.value.length + private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") + } + protected override val dim: Int = numFeaturesPlusIntercept + + /** + * Add a new training instance to this HingeAggregator, and update the loss and gradient + * of the objective function. 
+ * + * @param instance The instance of data point to be added. + * @return This HingeAggregator object. + */ + def add(instance: Instance): this.type = { + instance match { case Instance(label, weight, features) => + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + + s" Expecting $numFeatures but got ${features.size}.") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + val localFeaturesStd = bcFeaturesStd.value + val localCoefficients = coefficientsArray + val localGradientSumArray = gradientSumArray + + val dotProduct = { + var sum = 0.0 + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + sum += localCoefficients(index) * value / localFeaturesStd(index) + } + } + if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1) + sum + } + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val labelScaled = 2 * label - 1.0 + val loss = if (1.0 > labelScaled * dotProduct) { + (1.0 - labelScaled * dotProduct) * weight + } else { + 0.0 + } + + if (1.0 > labelScaled * dotProduct) { + val gradientScale = -labelScaled * weight + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index) + } + } + if (fitIntercept) { + localGradientSumArray(localGradientSumArray.length - 1) += gradientScale + } + } + + lossSum += loss + weightSum += weight + this + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index f2b00d0bae1d6..41a5d22dd6283 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.{ParamMap, ParamsSuite} +import org.apache.spark.ml.optim.aggregator.HingeAggregator +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -170,10 +171,10 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model2.intercept !== 0.0) } - test("sparse coefficients in SVCAggregator") { + test("sparse coefficients in HingeAggregator") { val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) - val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true) + val agg = new HingeAggregator(bcFeaturesStd, true)(bcCoefficients) val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") { intercept[IllegalArgumentException] { agg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala new file mode 100644 index 0000000000000..61b48ffa10944 --- /dev/null +++ 
b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import DifferentiableLossAggregatorSuite.getClassificationSummarizers + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantFeatureFiltered: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(0.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(1.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantFeatureFiltered = Array( + Instance(0.0, 0.1, Vectors.dense(2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0)), + Instance(2.0, 0.3, Vectors.dense(0.5)) + ) + } + + /** Get summary statistics for some data and create a new HingeAggregator. 
*/ + private def getNewAggregator( + instances: Array[Instance], + coefficients: Vector, + fitIntercept: Boolean): HingeAggregator = { + val (featuresSummarizer, ySummarizer) = + DifferentiableLossAggregatorSuite.getClassificationSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd) + val bcCoefficients = spark.sparkContext.broadcast(coefficients) + new HingeAggregator(bcFeaturesStd, fitIntercept)(bcCoefficients) + } + + test("aggregator add method input size") { + val coefArray = Array(1.0, 2.0) + val interceptArray = Array(2.0) + val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray), + fitIntercept = true) + withClue("HingeAggregator features dimension must match coefficients dimension") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, 1.0, Vectors.dense(2.0))) + } + } + } + + test("negative weight") { + val coefArray = Array(1.0, 2.0) + val interceptArray = Array(2.0) + val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray), + fitIntercept = true) + withClue("HingeAggregator does not support negative instance weights") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0))) + } + } + } + + test("check sizes") { + val rng = new scala.util.Random + val numFeatures = instances.head.features.size + val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble)) + val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble)) + val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true) + val aggNoIntercept = getNewAggregator(instances, coefWithoutIntercept, + fitIntercept = false) + instances.foreach(aggIntercept.add) + instances.foreach(aggNoIntercept.add) + + assert(aggIntercept.gradient.size === numFeatures + 1) + assert(aggNoIntercept.gradient.size === numFeatures) + } + + test("check correctness") { + val coefArray = Array(1.0, 2.0) + val intercept = 1.0 + val numFeatures = instances.head.features.size + val (featuresSummarizer, _) = getClassificationSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val weightSum = instances.map(_.weight).sum + + val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ Array(intercept)), + fitIntercept = true) + instances.foreach(agg.add) + + // compute the loss + val stdCoef = coefArray.indices.map(i => coefArray(i) / featuresStd(i)).toArray + val lossSum = instances.map { case Instance(l, w, f) => + val margin = BLAS.dot(Vectors.dense(stdCoef), f) + intercept + val labelScaled = 2 * l - 1.0 + if (1.0 > labelScaled * margin) { + (1.0 - labelScaled * margin) * w + } else { + 0.0 + } + }.sum + val loss = lossSum / weightSum + + // compute the gradients + val gradientCoef = new Array[Double](numFeatures) + var gradientIntercept = 0.0 + instances.foreach { case Instance(l, w, f) => + val margin = BLAS.dot(f, Vectors.dense(coefArray)) + intercept + if (1.0 > (2 * l - 1.0) * margin) { + gradientCoef.indices.foreach { i => + gradientCoef(i) += f(i) * -(2 * l - 1.0) * w / featuresStd(i) + } + gradientIntercept += -(2 * l - 1.0) * w + } + } + val gradient = Vectors.dense((gradientCoef ++ Array(gradientIntercept)).map(_ / weightSum)) + + assert(loss ~== agg.loss relTol 0.01) + assert(gradient ~== agg.gradient relTol 0.01) + } + + test("check with zero standard deviation") { + val binaryCoefArray = Array(1.0, 2.0) + val intercept = 
1.0 + val aggConstantFeatureBinary = getNewAggregator(instancesConstantFeature, + Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true) + instancesConstantFeature.foreach(aggConstantFeatureBinary.add) + + val aggConstantFeatureBinaryFiltered = getNewAggregator(instancesConstantFeatureFiltered, + Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true) + instancesConstantFeatureFiltered.foreach(aggConstantFeatureBinaryFiltered.add) + + // constant features should not affect gradient + assert(aggConstantFeatureBinary.gradient(0) === 0.0) + assert(aggConstantFeatureBinary.gradient(1) == aggConstantFeatureBinaryFiltered.gradient(0)) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala index 16ef4af4f94e8..4c7913d5d2577 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala @@ -217,8 +217,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { }.sum val loss = lossSum / weightSum - - // compute the gradients val gradientCoef = new Array[Double](numFeatures) var gradientIntercept = 0.0 From 7d16776d28da5bcf656f0d8556b15ed3a5edca44 Mon Sep 17 00:00:00 2001 From: mike Date: Fri, 25 Aug 2017 07:22:34 +0100 Subject: [PATCH 049/187] [SPARK-21255][SQL][WIP] Fixed NPE when creating encoder for enum ## What changes were proposed in this pull request? Fixed an NPE when creating an encoder for an enum. When you try to create an encoder for an Enum type (or a bean with an enum property) via Encoders.bean(...), it fails with a NullPointerException at TypeToken:495. I did a little research and it turns out that in JavaTypeInference the following code ``` def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") .filter(_.getReadMethod != null) } ``` filters out properties named "class", because we wouldn't want to serialize that. But enum types have another property of type Class named "declaringClass", which we are trying to inspect recursively. Eventually we try to inspect the ClassLoader class, which has a property "defaultAssertionStatus" with no read method, which leads to the NPE at TypeToken:495. I added the property name "declaringClass" to the filtering to resolve this. ## How was this patch tested? Unit test in JavaDatasetSuite which creates an encoder for an enum Author: mike Author: Mikhail Sveshnikov Closes #18488 from mike0sv/enum-support.
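The failure mode described above can be reproduced with plain bean introspection, outside Spark. The sketch below is illustrative only and not part of this patch (it uses `java.time.DayOfWeek` as an arbitrary stand-in enum): it shows that, besides "class", the Introspector also reports the "declaringClass" property of type Class, which is the property that used to be traversed recursively.
```
import java.beans.Introspector

// Illustrative only: list the bean properties reported for an arbitrary enum.
// Besides "class", the Introspector also reports "declaringClass" (of type Class),
// which is what used to be inspected recursively before this fix.
object EnumIntrospectionSketch {
  def main(args: Array[String]): Unit = {
    val props = Introspector.getBeanInfo(classOf[java.time.DayOfWeek]).getPropertyDescriptors
    props.foreach { p =>
      println(s"${p.getName}: type=${p.getPropertyType}, hasReadMethod=${p.getReadMethod != null}")
    }
  }
}
```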
--- .../sql/catalyst/JavaTypeInference.scala | 40 ++++++++++ .../catalyst/encoders/ExpressionEncoder.scala | 14 +++- .../expressions/objects/objects.scala | 4 +- .../apache/spark/sql/JavaDatasetSuite.java | 77 +++++++++++++++++++ 4 files changed, 131 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 21363d3ba82c1..33f6ce080c339 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Type-inference utilities for POJOs and Java collections. @@ -118,6 +119,10 @@ object JavaTypeInference { val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) (MapType(keyDataType, valueDataType, nullable), true) + case other if other.isEnum => + (StructType(Seq(StructField(typeToken.getRawType.getSimpleName, + StringType, nullable = false))), true) + case other => if (seenTypeSet.contains(other)) { throw new UnsupportedOperationException( @@ -140,6 +145,7 @@ object JavaTypeInference { def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + .filterNot(_.getName == "declaringClass") .filter(_.getReadMethod != null) } @@ -303,6 +309,11 @@ object JavaTypeInference { keyData :: valueData :: Nil, returnNullable = false) + case other if other.isEnum => + StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName", + expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other)) + :: getPath :: Nil) + case other => val properties = getJavaBeanReadableAndWritableProperties(other) val setters = properties.map { p => @@ -345,6 +356,30 @@ object JavaTypeInference { } } + /** Returns a mapping from enum value to int for given enum type */ + def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = { + assert(enum.isEnum) + inputObject: T => + UTF8String.fromString(inputObject.name()) + } + + /** Returns value index for given enum type and value */ + def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = { + enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject) + } + + /** Returns a mapping from int to enum value for given enum type */ + def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = { + assert(enum.isEnum) + value: InternalRow => + Enum.valueOf(enum, value.getUTF8String(0).toString) + } + + /** Returns enum value for given enum type and value index */ + def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = { + enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject) + } + private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { @@ -429,6 +464,11 @@ object JavaTypeInference { valueNullable = true ) + case other if other.isEnum => + CreateNamedStruct(expressions.Literal("enum") :: + 
StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName", + expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil) + case other => val properties = getJavaBeanReadableAndWritableProperties(other) val nonNullOutput = CreateNamedStruct(properties.flatMap { p => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..9ed5e120344b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -81,9 +81,19 @@ object ExpressionEncoder { ClassTag[T](cls)) } + def javaEnumSchema[T](beanClass: Class[T]): DataType = { + StructType(Seq(StructField("enum", + StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))), + nullable = false))) + } + // TODO: improve error message for java bean encoder. def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { - val schema = JavaTypeInference.inferDataType(beanClass)._1 + val schema = if (beanClass.isEnum) { + javaEnumSchema(beanClass) + } else { + JavaTypeInference.inferDataType(beanClass)._1 + } assert(schema.isInstanceOf[StructType]) val serializer = JavaTypeInference.serializerFor(beanClass) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9b28a18035b1c..7c466fe03cdcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -154,13 +154,13 @@ case class StaticInvoke( val evaluate = if (returnNullable) { if (ctx.defaultValue(dataType) == "null") { s""" - ${ev.value} = $callFunc; + ${ev.value} = (($javaType) ($callFunc)); ${ev.isNull} = ${ev.value} == null; """ } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc)); ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 4ca3b6406a328..a34474683013f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1283,6 +1283,83 @@ public void test() { ds.collectAsList(); } + public enum EnumBean { + A("www.elgoog.com"), + B("www.google.com"); + + private String url; + + EnumBean(String url) { + this.url = url; + } + + public String 
getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + } + + @Test + public void testEnum() { + List data = Arrays.asList(EnumBean.B); + Encoder encoder = Encoders.bean(EnumBean.class); + Dataset ds = spark.createDataset(data, encoder); + Assert.assertEquals(ds.collectAsList(), data); + } + + public static class BeanWithEnum { + EnumBean enumField; + String regularField; + + public String getRegularField() { + return regularField; + } + + public void setRegularField(String regularField) { + this.regularField = regularField; + } + + public EnumBean getEnumField() { + return enumField; + } + + public void setEnumField(EnumBean field) { + this.enumField = field; + } + + public BeanWithEnum(EnumBean enumField, String regularField) { + this.enumField = enumField; + this.regularField = regularField; + } + + public BeanWithEnum() { + } + + public String toString() { + return "BeanWithEnum(" + enumField + ", " + regularField + ")"; + } + + public boolean equals(Object other) { + if (other instanceof BeanWithEnum) { + BeanWithEnum beanWithEnum = (BeanWithEnum) other; + return beanWithEnum.regularField.equals(regularField) && beanWithEnum.enumField.equals(enumField); + } + return false; + } + } + + @Test + public void testBeanWithEnum() { + List data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"), + new BeanWithEnum(EnumBean.B, "flower boulevard")); + Encoder encoder = Encoders.bean(BeanWithEnum.class); + Dataset ds = spark.createDataset(data, encoder); + Assert.assertEquals(ds.collectAsList(), data); + } + + public static class EmptyBean implements Serializable {} + @Test From 574ef6c987c636210828e96d2f797d8f10aff05e Mon Sep 17 00:00:00 2001 From: zhoukang Date: Fri, 25 Aug 2017 22:59:31 +0800 Subject: [PATCH 050/187] [SPARK-21527][CORE] Use buffer limit in order to use JAVA NIO Util's buffercache ## What changes were proposed in this pull request? Right now, ChunkedByteBuffer#writeFully does not slice the bytes first. We observe the code in java nio Util#getTemporaryDirectBuffer below: BufferCache cache = bufferCache.get(); ByteBuffer buf = cache.get(size); if (buf != null) { return buf; } else { // No suitable buffer in the cache so we need to allocate a new // one. To avoid the cache growing then we remove the first // buffer from the cache and free it. if (!cache.isEmpty()) { buf = cache.removeFirst(); free(buf); } return ByteBuffer.allocateDirect(size); } If we slice first with a fixed size, we can use the buffer cache and only need to allocate at the first write call. Since we allocate a new buffer, we cannot control when that buffer is freed. This once caused a memory issue in our production cluster. In this patch, I supply a new API which slices with a fixed size for buffer writing. ## How was this patch tested? Unit tests and testing in production. Author: zhoukang Author: zhoukang Closes #18730 from caneGuy/zhoukang/improve-chunkwrite.
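The technique can be sketched independently of `ChunkedByteBuffer`. The helper below is a hypothetical illustration, not the patch code: it caps every `channel.write` at a fixed slice size by shrinking the buffer's limit, so the JDK's per-thread temporary direct-buffer cache can satisfy each write instead of one large direct buffer being allocated, and it remembers the original limit so slices after the first one are still written.
```
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

object ChunkedWriteSketch {
  // Hypothetical helper; chunkSize would come from a setting such as the
  // spark.buffer.write.chunkSize config entry added by this patch.
  def writeInSlices(bytes: ByteBuffer, channel: WritableByteChannel, chunkSize: Int): Unit = {
    val originalLimit = bytes.limit()
    while (bytes.position() < originalLimit) {
      // Temporarily cap the writable window at chunkSize bytes.
      bytes.limit(math.min(bytes.position() + chunkSize, originalLimit))
      // A channel may write fewer bytes than requested, so drain the slice fully.
      while (bytes.hasRemaining) {
        channel.write(bytes)
      }
    }
    bytes.limit(originalLimit) // restore the caller's view of the buffer
  }
}
```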
--- .../org/apache/spark/internal/config/package.scala | 9 +++++++++ .../org/apache/spark/util/io/ChunkedByteBuffer.scala | 11 ++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9495cd2835f97..0457a66af8e89 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -293,6 +293,15 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val BUFFER_WRITE_CHUNK_SIZE = + ConfigBuilder("spark.buffer.write.chunkSize") + .internal() + .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + + " ChunkedByteBuffer should not larger than Int.MaxValue.") + .createWithDefault(64 * 1024 * 1024) + private[spark] val CHECKPOINT_COMPRESS = ConfigBuilder("spark.checkpoint.compress") .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index f48bfd5c25f77..c28570fb24560 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -24,6 +24,8 @@ import java.nio.channels.WritableByteChannel import com.google.common.primitives.UnsignedBytes import io.netty.buffer.{ByteBuf, Unpooled} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.storage.StorageUtils @@ -40,6 +42,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { require(chunks != null, "chunks must not be null") require(chunks.forall(_.position() == 0), "chunks' positions must be 0") + // Chunk size in bytes + private val bufferWriteChunkSize = + Option(SparkEnv.get).map(_.conf.get(config.BUFFER_WRITE_CHUNK_SIZE)) + .getOrElse(config.BUFFER_WRITE_CHUNK_SIZE.defaultValue.get).toInt + private[this] var disposed: Boolean = false /** @@ -56,7 +63,9 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining > 0) { + while (bytes.remaining() > 0) { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position + ioSize) channel.write(bytes) } } From de7af295c2047f1b508cb02e735e0e743395f181 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 25 Aug 2017 16:07:13 +0100 Subject: [PATCH 051/187] [MINOR][BUILD] Fix build warnings and Java lint errors ## What changes were proposed in this pull request? Fix build warnings and Java lint errors. This just helps a bit in evaluating (new) warnings in another PR I have open. ## How was this patch tested? Existing tests Author: Sean Owen Closes #19051 from srowen/JavaWarnings. 
--- .../java/org/apache/spark/util/kvstore/InMemoryStore.java | 2 +- .../org/apache/spark/util/kvstore/KVStoreIterator.java | 3 ++- .../apache/spark/network/TransportRequestHandlerSuite.java | 7 +++++-- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 1 - .../org/apache/spark/launcher/ChildProcAppHandleSuite.java | 1 - .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 7 +++---- .../apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 7 +++---- pom.xml | 2 +- .../datasources/parquet/VectorizedColumnReader.java | 3 ++- .../spark/sql/execution/vectorized/AggregateHashMap.java | 1 - .../spark/sql/execution/vectorized/ArrowColumnVector.java | 1 - 11 files changed, 17 insertions(+), 18 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java index 9cae5da5d2600..5ca4371285198 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java @@ -171,7 +171,7 @@ public int size() { public InMemoryView view(Class type) { Preconditions.checkArgument(ti.type().equals(type), "Unexpected type: %s", type); Collection all = (Collection) data.values(); - return new InMemoryView(type, all, ti); + return new InMemoryView<>(type, all, ti); } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java index 28a432b26d98e..e6254a9368ff5 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreIterator.java @@ -17,6 +17,7 @@ package org.apache.spark.util.kvstore; +import java.io.Closeable; import java.util.Iterator; import java.util.List; @@ -31,7 +32,7 @@ *

*/ @Private -public interface KVStoreIterator extends Iterator, AutoCloseable { +public interface KVStoreIterator extends Iterator, Closeable { /** * Retrieve multiple elements from the store. diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 1ed57116bc7bf..2656cbee95a20 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -102,7 +102,7 @@ public void handleFetchRequestAndStreamRequest() throws Exception { private class ExtendedChannelPromise extends DefaultChannelPromise { - private List listeners = new ArrayList<>(); + private List>> listeners = new ArrayList<>(); private boolean success; ExtendedChannelPromise(Channel channel) { @@ -113,7 +113,10 @@ private class ExtendedChannelPromise extends DefaultChannelPromise { @Override public ChannelPromise addListener( GenericFutureListener> listener) { - listeners.add(listener); + @SuppressWarnings("unchecked") + GenericFutureListener> gfListener = + (GenericFutureListener>) listener; + listeners.add(gfListener); return super.addListener(listener); } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index db4fc26cdf353..ac4391e3ef99b 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -21,7 +21,6 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Before; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java index 64a87b365d6a9..602f55a50564d 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.launcher; import java.io.File; -import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index dc6043ef19fe2..90778d7890064 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -187,14 +187,13 @@ class CrossValidatorSuite cv2.getEstimator match { case ova2: OneVsRest => assert(ova.uid === ova2.uid) - val classifier = ova2.getClassifier - classifier match { + ova2.getClassifier match { case lr: LogisticRegression => assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) - case _ => + case other => throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticREgression but found ${classifier.getClass.getName}") + s" LogisticRegression but found ${other.getClass.getName}") } case other => diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 
7c97865e45202..aa8b4cf173cc3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -173,14 +173,13 @@ class TrainValidationSplitSuite tvs2.getEstimator match { case ova2: OneVsRest => assert(ova.uid === ova2.uid) - val classifier = ova2.getClassifier - classifier match { + ova2.getClassifier match { case lr: LogisticRegression => assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) - case _ => + case other => throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticREgression but found ${classifier.getClass.getName}") + s" LogisticRegression but found ${other.getClass.getName}") } case other => diff --git a/pom.xml b/pom.xml index 8b4a6c5425a98..fffd70ec1d929 100644 --- a/pom.xml +++ b/pom.xml @@ -2058,7 +2058,7 @@ ${java.version} -target ${java.version} - -Xlint:all,-serial,-path + -Xlint:all,-serial,-path,-try diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index f37864a0f5393..2173bbce3eea9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -350,7 +350,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) throws IOException { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { assert(column.dataType() == DataTypes.BooleanType); defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 1c94f706dc685..cb3ad4eab1f60 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -21,7 +21,6 @@ import com.google.common.annotations.VisibleForTesting; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.LongType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index be2a9c246747c..1f171049820b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -21,7 +21,6 @@ import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; From 1f24ceee606f17c4f3ca969fa4b5631256fa09e8 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 25 Aug 2017 08:59:48 -0700 Subject: [PATCH 052/187] [SPARK-21832][TEST] Merge 
SQLBuilderTest into ExpressionSQLBuilderSuite ## What changes were proposed in this pull request? After [SPARK-19025](https://github.com/apache/spark/pull/16869), there is no need to keep SQLBuilderTest. ExpressionSQLBuilderSuite is the only place to use it. This PR aims to remove SQLBuilderTest. ## How was this patch tested? Pass the updated `ExpressionSQLBuilderSuite`. Author: Dongjoon Hyun Closes #19044 from dongjoon-hyun/SPARK-21832. --- .../catalyst/ExpressionSQLBuilderSuite.scala | 23 ++++++++-- .../spark/sql/catalyst/SQLBuilderTest.scala | 44 ------------------- 2 files changed, 20 insertions(+), 47 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index 90f90599d5bf4..d9cf1f361c1d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -19,12 +19,29 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd, - TimeSub, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.unsafe.types.CalendarInterval -class ExpressionSQLBuilderSuite extends SQLBuilderTest { +class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL == expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + test("literal") { checkSQL(Literal("foo"), "'foo'") checkSQL(Literal("\"foo\""), "'\"foo\"'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala deleted file mode 100644 index 157783abc8c2f..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst - -import scala.util.control.NonFatal - -import org.apache.spark.sql.{DataFrame, Dataset, QueryTest} -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.hive.test.TestHiveSingleton - - -abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { - protected def checkSQL(e: Expression, expectedSQL: String): Unit = { - val actualSQL = e.sql - try { - assert(actualSQL === expectedSQL) - } catch { - case cause: Throwable => - fail( - s"""Wrong SQL generated for the following expression: - | - |${e.prettyName} - | - |$cause - """.stripMargin) - } - } -} From 1813c4a8dd4388fe76a4ec772c9be151be0f60a1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 25 Aug 2017 09:57:53 -0700 Subject: [PATCH 053/187] [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode ## What changes were proposed in this pull request? With SPARK-10643, Spark supports downloading remote resources in client deploy mode. But the implementation overrides the variables representing added resources (like `args.jars`, `args.pyFiles`) with local paths, and the YARN client then uses these local paths to re-upload the resources to the distributed cache. This unnecessarily breaks the semantics of putting resources in a shared FS, so this patch fixes it. ## How was this patch tested? This was manually verified with jars and pyFiles in local and remote storage, in both client and cluster mode. Author: jerryshao Closes #18962 from jerryshao/SPARK-21714. --- .../org/apache/spark/deploy/SparkSubmit.scala | 64 +++++++++++------ .../spark/internal/config/package.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 25 ++++--- .../spark/deploy/SparkSubmitSuite.scala | 70 +++++++++++++++---- .../scala/org/apache/spark/repl/Main.scala | 2 +- 5 files changed, 114 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e56925102d47e..548149a88a49d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -212,14 +212,20 @@ object SparkSubmit extends CommandLineUtils { /** * Prepare the environment for submitting an application. - * This returns a 4-tuple: - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a map of system properties, and - * (4) the main class for the child + * + * @param args the parsed SparkSubmitArguments used for environment preparation. + * @param conf the Hadoop Configuration, this argument will only be set in unit test. + * @return a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * + * Exposed for testing.
*/ - private[deploy] def prepareSubmitEnvironment(args: SparkSubmitArguments) + private[deploy] def prepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], Map[String, String], String) = { // Return values val childArgs = new ArrayBuffer[String]() @@ -322,7 +328,7 @@ object SparkSubmit extends CommandLineUtils { } } - val hadoopConf = new HadoopConfiguration() + val hadoopConf = conf.getOrElse(new HadoopConfiguration()) val targetDir = DependencyUtils.createTempDir() // Resolve glob path for different resources. @@ -332,19 +338,21 @@ object SparkSubmit extends CommandLineUtils { args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull // In client mode, download remote files. + var localPrimaryResource: String = null + var localJars: String = null + var localPyFiles: String = null if (deployMode == CLIENT) { - args.primaryResource = Option(args.primaryResource).map { + localPrimaryResource = Option(args.primaryResource).map { downloadFile(_, targetDir, args.sparkProperties, hadoopConf) }.orNull - args.jars = Option(args.jars).map { + localJars = Option(args.jars).map { downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) }.orNull - args.pyFiles = Option(args.pyFiles).map { + localPyFiles = Option(args.pyFiles).map { downloadFileList(_, targetDir, args.sparkProperties, hadoopConf) }.orNull } - // If we're running a python app, set the main class to our specific python runner if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { @@ -353,7 +361,7 @@ object SparkSubmit extends CommandLineUtils { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" - args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs + args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs if (clusterManager != YARN) { // The YARN backend distributes the primary file differently, so don't merge it. args.files = mergeFileLists(args.files, args.primaryResource) @@ -363,8 +371,8 @@ object SparkSubmit extends CommandLineUtils { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + if (localPyFiles != null) { + sysProps("spark.submit.pyFiles") = localPyFiles } } @@ -418,7 +426,7 @@ object SparkSubmit extends CommandLineUtils { // If an R file is provided, add it to the child arguments and list of files to deploy. // Usage: RRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.RRunner" - args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.childArgs = ArrayBuffer(localPrimaryResource) ++ args.childArgs args.files = mergeFileLists(args.files, args.primaryResource) } } @@ -463,6 +471,7 @@ object SparkSubmit extends CommandLineUtils { OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), @@ -486,15 +495,28 @@ object SparkSubmit extends CommandLineUtils { sysProp = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + + // An internal option used only for spark-shell to add user jars to repl's classloader, + // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to + // remote jars, so adding a new option to only specify local jars for spark-shell internally. + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") ) // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - // Also add the main application jar and any added jars to classpath in case YARN client - // requires these jars. - if (deployMode == CLIENT || isYarnCluster) { + if (deployMode == CLIENT) { childMainClass = args.mainClass + if (localPrimaryResource != null && isUserJar(localPrimaryResource)) { + childClasspath += localPrimaryResource + } + if (localJars != null) { childClasspath ++= localJars.split(",") } + } + // Add the main application jar and any added jars to classpath in case YARN client + // requires these jars. + // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // added to the classpath of YARN client. 
+ if (isYarnCluster) { if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } @@ -551,10 +573,6 @@ object SparkSubmit extends CommandLineUtils { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } - - if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles - } } // assure a keytab is available from any place in a JVM diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0457a66af8e89..0d3769a735869 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -87,7 +87,7 @@ package object config { .intConf .createOptional - private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles") + private[spark] val PY_FILES = ConfigBuilder("spark.yarn.dist.pyFiles") .internal() .stringConf .toSequence diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 900a619421903..3dce76c2c96ba 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2601,18 +2601,23 @@ private[spark] object Utils extends Logging { } /** - * In YARN mode this method returns a union of the jar files pointed by "spark.jars" and the - * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed by - * only the "spark.jars" property. + * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute + * these jars through file server. In the YARN mode, it will return an empty list, since YARN + * has its own mechanism to distribute jars. */ - def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = { + def getUserJars(conf: SparkConf): Seq[String] = { val sparkJars = conf.getOption("spark.jars") - if (conf.get("spark.master") == "yarn" && isShell) { - val yarnJars = conf.getOption("spark.yarn.dist.jars") - unionFileLists(sparkJars, yarnJars).toSeq - } else { - sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten - } + sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten + } + + /** + * Return the local jar files which will be added to REPL's classpath. These jar files are + * specified by --jars (spark.jars) or --packages, remote jars will be downloaded to local by + * SparkSubmit at first. 
+ */ + def getLocalUserJarsForShell(conf: SparkConf): Seq[String] = { + val localJars = conf.getOption("spark.repl.local.jars") + localJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten } private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 08ba41f50a2b9..95137c868cbaf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -29,7 +29,7 @@ import scala.io.Source import com.google.common.io.ByteStreams import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -762,7 +762,7 @@ class SparkSubmitSuite (Set(jar1.toURI.toString, jar2.toURI.toString)) sysProps("spark.yarn.dist.files").split(",").toSet should be (Set(file1.toURI.toString, file2.toURI.toString)) - sysProps("spark.submit.pyFiles").split(",").toSet should be + sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) sysProps("spark.yarn.dist.archives").split(",").toSet should be (Set(archive1.toURI.toString, archive2.toURI.toString)) @@ -802,10 +802,7 @@ class SparkSubmitSuite test("downloadFile - file doesn't exist") { val hadoopConf = new Configuration() val tmpDir = Utils.createTempDir() - // Set s3a implementation to local file system for testing. - hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") - // Disable file system impl cache to make sure the test file system is picked up. - hadoopConf.set("fs.s3a.impl.disable.cache", "true") + updateConfWithFakeS3Fs(hadoopConf) intercept[FileNotFoundException] { SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf) } @@ -826,10 +823,7 @@ class SparkSubmitSuite FileUtils.write(jarFile, content) val hadoopConf = new Configuration() val tmpDir = Files.createTempDirectory("tmp").toFile - // Set s3a implementation to local file system for testing. - hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") - // Disable file system impl cache to make sure the test file system is picked up. - hadoopConf.set("fs.s3a.impl.disable.cache", "true") + updateConfWithFakeS3Fs(hadoopConf) val sourcePath = s"s3a://${jarFile.getAbsolutePath}" val outputPath = SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf) @@ -844,10 +838,7 @@ class SparkSubmitSuite FileUtils.write(jarFile, content) val hadoopConf = new Configuration() val tmpDir = Files.createTempDirectory("tmp").toFile - // Set s3a implementation to local file system for testing. - hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") - // Disable file system impl cache to make sure the test file system is picked up. 
- hadoopConf.set("fs.s3a.impl.disable.cache", "true") + updateConfWithFakeS3Fs(hadoopConf) val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") val outputPaths = SparkSubmit.downloadFileList( sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",") @@ -859,6 +850,43 @@ class SparkSubmitSuite } } + test("Avoid re-upload remote resources in yarn client mode") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val file = File.createTempFile("tmpFile", "", tmpDir) + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpJarPath = s"s3a://${new File(tmpJar.toURI).getAbsolutePath}" + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", tmpJarPath, + "--files", s"s3a://${file.getAbsolutePath}", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + s"s3a://$mainResource" + ) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + + // All the resources should still be remote paths, so that YARN client will not upload again. + sysProps("spark.yarn.dist.jars") should be (tmpJarPath) + sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + + // Local repl jars should be a local path. + sysProps("spark.repl.local.jars") should (startWith("file:")) + + // local py files should not be a URI format. + sysProps("spark.submit.pyFiles") should (startWith("/")) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -898,6 +926,11 @@ class SparkSubmitSuite Utils.deleteRecursively(tmpDir) } } + + private def updateConfWithFakeS3Fs(conf: Configuration): Unit = { + conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) + conf.set("fs.s3a.impl.disable.cache", "true") + } } object JarCreationTest extends Logging { @@ -967,4 +1000,13 @@ class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { // Ignore the scheme for testing. 
super.copyToLocalFile(new Path(src.toUri.getPath), dst) } + + override def globStatus(pathPattern: Path): Array[FileStatus] = { + val newPath = new Path(pathPattern.toUri.getPath) + super.globStatus(newPath).map { status => + val path = s"s3a://${status.getPath.toUri.getPath}" + status.setPath(new Path(path)) + status + } + } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 9702a1e653c32..0b16e1b073e32 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -57,7 +57,7 @@ object Main extends Logging { // Visible for testing private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { interp = _interp - val jars = Utils.getUserJars(conf, isShell = true) + val jars = Utils.getLocalUserJarsForShell(conf) // Remove file:///, file:// or file:/ scheme if exists for each jar .map { x => if (x.startsWith("file:")) new File(new URI(x)).getPath else x } .mkString(File.pathSeparator) From 628bdeabda3347d0903c9ac8748d37d7b379d1e6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 25 Aug 2017 10:04:21 -0700 Subject: [PATCH 054/187] [SPARK-17742][CORE] Fail launcher app handle if child process exits with error. This is a follow up to cba826d0; that commit set the app handle state to "LOST" when the child process exited, but that can be ambiguous. This change sets the state to "FAILED" if the exit code was non-zero and the handle state wasn't a failure state, or "LOST" if the exit status was zero. Author: Marcelo Vanzin Closes #19012 from vanzin/SPARK-17742. --- .../spark/launcher/ChildProcAppHandle.java | 27 ++++++++++++++----- .../launcher/ChildProcAppHandleSuite.java | 21 ++++++++++++++- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index bf916406f1471..5391d4a50fe47 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -156,9 +156,15 @@ synchronized void setAppId(String appId) { * the exit code. */ void monitorChild() { - while (childProc.isAlive()) { + Process proc = childProc; + if (proc == null) { + // Process may have already been disposed of, e.g. by calling kill(). + return; + } + + while (proc.isAlive()) { try { - childProc.waitFor(); + proc.waitFor(); } catch (Exception e) { LOG.log(Level.WARNING, "Exception waiting for child process to exit.", e); } @@ -173,15 +179,24 @@ void monitorChild() { int ec; try { - ec = childProc.exitValue(); + ec = proc.exitValue(); } catch (Exception e) { LOG.log(Level.WARNING, "Exception getting child process exit code, assuming failure.", e); ec = 1; } - // Only override the success state; leave other fail states alone. - if (!state.isFinal() || (ec != 0 && state == State.FINISHED)) { - state = State.LOST; + State newState = null; + if (ec != 0) { + // Override state with failure if the current state is not final, or is success. 
+ if (!state.isFinal() || state == State.FINISHED) { + newState = State.FAILED; + } + } else if (!state.isFinal()) { + newState = State.LOST; + } + + if (newState != null) { + state = newState; fireEvent(false); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java index 602f55a50564d..3b4d1b07f606e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java @@ -46,7 +46,9 @@ public class ChildProcAppHandleSuite extends BaseSuite { private static final List TEST_SCRIPT = Arrays.asList( "#!/bin/sh", "echo \"output\"", - "echo \"error\" 1>&2"); + "echo \"error\" 1>&2", + "while [ -n \"$1\" ]; do EC=$1; shift; done", + "exit $EC"); private static File TEST_SCRIPT_PATH; @@ -176,6 +178,7 @@ public void testRedirectErrorTwiceFails() throws Exception { @Test public void testProcMonitorWithOutputRedirection() throws Exception { + assumeFalse(isWindows()); File err = Files.createTempFile("out", "txt").toFile(); SparkAppHandle handle = new TestSparkLauncher() .redirectError() @@ -187,6 +190,7 @@ public void testProcMonitorWithOutputRedirection() throws Exception { @Test public void testProcMonitorWithLogRedirection() throws Exception { + assumeFalse(isWindows()); SparkAppHandle handle = new TestSparkLauncher() .redirectToLog(getClass().getName()) .startApplication(); @@ -194,6 +198,16 @@ public void testProcMonitorWithLogRedirection() throws Exception { assertEquals(SparkAppHandle.State.LOST, handle.getState()); } + @Test + public void testFailedChildProc() throws Exception { + assumeFalse(isWindows()); + SparkAppHandle handle = new TestSparkLauncher(1) + .redirectToLog(getClass().getName()) + .startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + private void waitFor(SparkAppHandle handle) throws Exception { long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); try { @@ -212,7 +226,12 @@ private void waitFor(SparkAppHandle handle) throws Exception { private static class TestSparkLauncher extends SparkLauncher { TestSparkLauncher() { + this(0); + } + + TestSparkLauncher(int ec) { setAppResource("outputredirtest"); + addAppArgs(String.valueOf(ec)); } @Override From 51620e288b5e0a7fffc3899c9deadabace28e6d7 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Fri, 25 Aug 2017 10:18:03 -0700 Subject: [PATCH 055/187] [SPARK-21756][SQL] Add JSON option to allow unquoted control characters ## What changes were proposed in this pull request? This patch adds allowUnquotedControlChars option in JSON data source to allow JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) ## How was this patch tested? Add new test cases Author: vinodkc Closes #19008 from vinodkc/br_fix_SPARK-21756. 
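
As an illustration of how the new option is meant to be used (a minimal sketch, not part of the patch itself: the object name, the local SparkSession, and the inline JSON record are assumed for the example; only the option name `allowUnquotedControlChars` and the DataFrameReader API come from this change):

    import org.apache.spark.sql.SparkSession

    object AllowUnquotedControlCharsExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("allowUnquotedControlChars example")
          .getOrCreate()
        import spark.implicits._

        // The "\t" below is a real tab character inside the JSON string value,
        // i.e. an unquoted control character (ASCII value < 32).
        val json = Seq("{\"name\": \"a\tb\"}").toDS()

        // Without the option the raw tab makes the record unparseable, so it is
        // handled as a corrupt record under the default PERMISSIVE mode.
        spark.read.json(json).show()

        // With the new option enabled the record parses normally.
        spark.read.option("allowUnquotedControlChars", "true").json(json).show()

        spark.stop()
      }
    }

The same option name is also exposed through the Python `DataFrameReader` and `DataStreamReader` wrappers touched below.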
--- python/pyspark/sql/readwriter.py | 8 ++++++-- python/pyspark/sql/streaming.py | 8 ++++++-- .../spark/sql/catalyst/json/JSONOptions.scala | 3 +++ .../org/apache/spark/sql/DataFrameReader.scala | 3 +++ .../spark/sql/streaming/DataStreamReader.scala | 3 +++ .../json/JsonParsingOptionsSuite.scala | 15 +++++++++++++++ 6 files changed, 36 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7279173df6e4f..01da0dc27d83d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None): + multiLine=None, allowUnquotedControlChars=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -234,6 +234,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. + :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control + characters (ASCII characters with value less than 32, + including tab and line feed characters) or not. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -250,7 +253,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, multiLine=multiLine) + timestampFormat=timestampFormat, multiLine=multiLine, + allowUnquotedControlChars=allowUnquotedControlChars) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 5bbd70cf0a789..0cf702143c773 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -407,7 +407,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None): + multiLine=None, allowUnquotedControlChars=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -467,6 +467,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. + :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control + characters (ASCII characters with value less than 32, + including tab and line feed characters) or not. 
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -480,7 +483,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, multiLine=multiLine) + timestampFormat=timestampFormat, multiLine=multiLine, + allowUnquotedControlChars=allowUnquotedControlChars) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 1fd680ab64b5a..652412b34478a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -64,6 +64,8 @@ private[sql] class JSONOptions( parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + private val allowUnquotedControlChars = + parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) val parseMode: ParseMode = parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) @@ -92,5 +94,6 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, allowBackslashEscapingAnyCharacter) + factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 41cb019499ae1..8209cec4ba0a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,6 +313,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * (e.g. 00012) *
    * <li>`allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
    * character using backslash quoting mechanism</li>
+   * <li>`allowUnquotedControlChars` (default `false`): allows JSON Strings to contain unquoted
+   * control characters (ASCII characters with value less than 32, including tab and line feed
+   * characters) or not.</li>
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
    * during parsing.
    *   <ul>