diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index cb015d7301c1..51b1778ec653 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -292,4 +292,6 @@ private[hive] trait HiveClient { /** Used for testing only. Removes all metadata from this instance of Hive. */ def reset(): Unit + /** Returns the user name which is used as owner for Hive table. */ + def userName: String } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index ae9eca823d00..96e61bd54280 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -222,7 +222,7 @@ private[hive] class HiveClientImpl( hiveConf } - private val userName = UserGroupInformation.getCurrentUser.getShortUserName + override val userName = UserGroupInformation.getCurrentUser.getShortUserName override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuite.scala new file mode 100644 index 000000000000..77956f4fe69d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.security.PrivilegedExceptionAction + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} + +import org.apache.spark.util.Utils + +class HiveClientUserNameSuite(version: String) extends HiveVersionSuite(version) { + + test("username of HiveClient - no UGI") { + // Assuming we're not faking System username + assert(getUserNameFromHiveClient === System.getProperty("user.name")) + } + + test("username of HiveClient - UGI") { + val ugi = UserGroupInformation.createUserForTesting( + "fakeprincipal@EXAMPLE.COM", Array.empty) + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + assert(getUserNameFromHiveClient === ugi.getShortUserName) + } + }) + } + + test("username of HiveClient - Proxy user") { + val ugi = UserGroupInformation.createUserForTesting( + "fakeprincipal@EXAMPLE.COM", Array.empty) + val proxyUgi = UserGroupInformation.createProxyUserForTesting( + "proxyprincipal@EXAMPLE.COM", ugi, Array.empty) + proxyUgi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + assert(getUserNameFromHiveClient === proxyUgi.getShortUserName) + } + }) + } + + private def getUserNameFromHiveClient: String = { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.warehouse.dir", Utils.createTempDir().toURI().toString()) + val client = buildClient(hadoopConf) + client.userName + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuites.scala new file mode 100644 index 000000000000..e076c01c0898 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientUserNameSuites.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.Suite + +class HiveClientUserNameSuites extends Suite with HiveClientVersions { + override def nestedSuites: IndexedSeq[Suite] = { + versions.map(new HiveClientUserNameSuite(_)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala similarity index 99% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala index bda711200acd..5f4ee7d7f1c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StructType} import org.apache.spark.util.Utils -// TODO: Refactor this to `HivePartitionFilteringSuite` -class HiveClientSuite(version: String) +class HivePartitionFilteringSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuites.scala similarity index 87% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuites.scala index de1be2115b2d..a43e778b13b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuites.scala @@ -21,9 +21,9 @@ import scala.collection.immutable.IndexedSeq import org.scalatest.Suite -class HiveClientSuites extends Suite with HiveClientVersions { +class HivePartitionFilteringSuites extends Suite with HiveClientVersions { override def nestedSuites: IndexedSeq[Suite] = { // Hive 0.12 does not provide the partition filtering API we call - versions.filterNot(_ == "0.12").map(new HiveClientSuite(_)) + versions.filterNot(_ == "0.12").map(new HivePartitionFilteringSuite(_)) } }