Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,30 @@
*/
package org.apache.spark.sql.hive.thriftserver

import java.lang.reflect.InvocationTargetException
import java.nio.ByteBuffer
import java.util.UUID

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hive.service.cli.OperationHandle
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, OperationManager}
import org.apache.hive.service.cli.session.{HiveSessionImpl, SessionManager}
import org.mockito.Mockito.{mock, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl, SessionManager}
import org.apache.hive.service.rpc.thrift.{THandleIdentifier, TOperationHandle, TOperationType}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the compilation error, this import is invalid in Hive 1.2.


import org.apache.spark.SparkFunSuite

class HiveSessionImplSuite extends SparkFunSuite {
private var session: HiveSessionImpl = _
private var operationManager: OperationManager = _
private var operationManager: OperationManagerMock = _

override def beforeAll() {
super.beforeAll()

// mock the instance first - we observed weird classloader issue on creating mock, so
// would like to avoid any cases classloader gets switched
val sessionManager = mock(classOf[SessionManager])
operationManager = mock(classOf[OperationManager])
val sessionManager = new SessionManager(null)
operationManager = new OperationManagerMock()

session = new HiveSessionImpl(
ThriftserverShimUtils.testedProtocolVersions.head,
Expand All @@ -48,13 +50,6 @@ class HiveSessionImplSuite extends SparkFunSuite {
)
session.setSessionManager(sessionManager)
session.setOperationManager(operationManager)
when(operationManager.newGetCatalogsOperation(session)).thenAnswer(
(_: InvocationOnMock) => {
val operation = mock(classOf[GetCatalogsOperation])
when(operation.getHandle).thenReturn(mock(classOf[OperationHandle]))
operation
}
)

session.open(Map.empty[String, String].asJava)
}
Expand All @@ -63,14 +58,59 @@ class HiveSessionImplSuite extends SparkFunSuite {
val operationHandle1 = session.getCatalogs
val operationHandle2 = session.getCatalogs

when(operationManager.closeOperation(operationHandle1))
.thenThrow(classOf[NullPointerException])
when(operationManager.closeOperation(operationHandle2))
.thenThrow(classOf[NullPointerException])

session.close()

verify(operationManager).closeOperation(operationHandle1)
verify(operationManager).closeOperation(operationHandle2)
assert(operationManager.getCalledHandles.contains(operationHandle1))
assert(operationManager.getCalledHandles.contains(operationHandle2))
}
}

class GetCatalogsOperationMock(parentSession: HiveSession)
extends GetCatalogsOperation(parentSession) {

override def runInternal(): Unit = {}

override def getHandle: OperationHandle = {
val uuid: UUID = UUID.randomUUID()
val tHandleIdentifier: THandleIdentifier = new THandleIdentifier()
tHandleIdentifier.setGuid(getByteBufferFromUUID(uuid))
tHandleIdentifier.setSecret(getByteBufferFromUUID(uuid))
val tOperationHandle: TOperationHandle = new TOperationHandle()
tOperationHandle.setOperationId(tHandleIdentifier)
tOperationHandle.setOperationType(TOperationType.GET_TYPE_INFO)
tOperationHandle.setHasResultSetIsSet(false)
new OperationHandle(tOperationHandle)
}

private def getByteBufferFromUUID(uuid: UUID): Array[Byte] = {
val bb: ByteBuffer = ByteBuffer.wrap(new Array[Byte](16))
bb.putLong(uuid.getMostSignificantBits)
bb.putLong(uuid.getLeastSignificantBits)
bb.array
}
}

class OperationManagerMock extends OperationManager {
private val calledHandles: mutable.Set[OperationHandle] = new mutable.HashSet[OperationHandle]()

override def newGetCatalogsOperation(parentSession: HiveSession): GetCatalogsOperation = {
val operation = new GetCatalogsOperationMock(parentSession)
try {
val m = classOf[OperationManager].getDeclaredMethod("addOperation", classOf[Operation])
m.setAccessible(true)
m.invoke(this, operation)
} catch {
case e@(_: NoSuchMethodException | _: IllegalAccessException |
_: InvocationTargetException) =>
throw new RuntimeException(e)
}
operation
}

override def closeOperation(opHandle: OperationHandle): Unit = {
calledHandles.add(opHandle)
throw new RuntimeException
}

def getCalledHandles: mutable.Set[OperationHandle] = calledHandles
}