Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ package org.apache.spark.util

import java.util.concurrent.Callable

import com.google.common.cache.Cache
import com.google.common.cache.LoadingCache
import com.google.common.cache.{Cache, CacheBuilder, CacheLoader, LoadingCache}

/**
* SPARK-43300: Guava cache fate-sharing behavior might lead to unexpected cascade failure:
Expand Down Expand Up @@ -50,6 +49,25 @@ private[spark] object NonFateSharingCache {

def apply[K, V](loadingCache: LoadingCache[K, V]): NonFateSharingLoadingCache[K, V] =
new NonFateSharingLoadingCache(loadingCache)

/**
 * SPARK-44064: builds a [[NonFateSharingLoadingCache]] directly from a plain loading function.
 *
 * This overload exists so that code in non-core modules does not have to pass Guava cache
 * types as arguments to the other `NonFateSharingCache#apply` overloads. Doing so avoids
 * Maven test failures in non-core modules that are caused by using the shaded core module.
 * This function should be made more general when further requirements appear, or removed
 * once Maven testing is no longer supported.
 *
 * @param loadingFunc function invoked to compute the value for a key; must not be null
 * @param maximumSize size bound handed to the underlying cache builder; it is only applied
 *                    when it is greater than zero, otherwise no size bound is configured
 * @return a loading cache wrapped in [[NonFateSharingLoadingCache]]
 */
def apply[K, V](loadingFunc: K => V, maximumSize: Long = 0L): NonFateSharingLoadingCache[K, V] = {
  require(loadingFunc != null)
  val baseBuilder = CacheBuilder.newBuilder().asInstanceOf[CacheBuilder[K, V]]
  // Only a strictly positive size is forwarded to the builder.
  val configuredBuilder =
    if (maximumSize > 0L) baseBuilder.maximumSize(maximumSize) else baseBuilder
  // Adapt the Scala function to the loader type expected by the builder.
  val loader = new CacheLoader[K, V] {
    override def load(key: K): V = loadingFunc(key)
  }
  new NonFateSharingLoadingCache(configuredBuilder.build[K, V](loader))
}
}

private[spark] class NonFateSharingCache[K, V](protected val cache: Cache[K, V]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object NonFateSharingCacheSuite {
private val FAIL_MESSAGE = "loading failed"
private val THREAD2_HOLDER = new AtomicReference[Thread](null)

class TestCacheLoader extends CacheLoader[String, String] {
trait InternalTestStatus {
var intentionalFail: ThreadLocal[Boolean] = ThreadLocal.withInitial(() => false)
var startLoading = new Semaphore(0)

Expand All @@ -43,7 +43,10 @@ object NonFateSharingCacheSuite {
}
}
}
}


class TestCacheLoader extends CacheLoader[String, String] with InternalTestStatus {
override def load(key: String): String = {
startLoading.release()
if (Thread.currentThread().getName.contains("test-executor1")) {
Expand All @@ -53,6 +56,18 @@ object NonFateSharingCacheSuite {
key
}
}

/**
 * Function-shaped test loader that shares its coordination state with
 * `InternalTestStatus` (the `startLoading` semaphore and the `intentionalFail` flag).
 */
class TestLoaderFunc extends InternalTestStatus {
  /**
   * Loading function that returns the key itself as the value. It first releases
   * `startLoading`, then, when running on a thread whose name contains "test-executor1",
   * calls `waitUntilThread2Waiting()`, and finally throws a `RuntimeException` if the
   * calling thread has set `intentionalFail`.
   */
  def func: String => String = { key =>
    startLoading.release()
    val runningOnExecutor1 = Thread.currentThread().getName.contains("test-executor1")
    if (runningOnExecutor1) {
      waitUntilThread2Waiting()
    }
    if (intentionalFail.get) {
      throw new RuntimeException(FAIL_MESSAGE)
    }
    key
  }
}

}

/**
Expand All @@ -66,16 +81,22 @@ class NonFateSharingCacheSuite extends SparkFunSuite {

test("loading cache loading failure should not affect concurrent query on same key") {
val loader = new TestCacheLoader
val loadingCache: NonFateSharingLoadingCache[String, String] =
val loadingCache0: NonFateSharingLoadingCache[String, String] =
NonFateSharingCache(CacheBuilder.newBuilder.build(loader))
val thread1Task: WorkerFunc = () => {
loader.intentionalFail.set(true)
loadingCache.get(TEST_KEY)
}
val thread2Task: WorkerFunc = () => {
loadingCache.get(TEST_KEY)
val loaderFunc = new TestLoaderFunc
val loadingCache1: NonFateSharingLoadingCache[String, String] =
NonFateSharingCache(loaderFunc.func)
Seq((loadingCache0, loader), (loadingCache1, loaderFunc)).foreach {
case (loadingCache, status) =>
val thread1Task: WorkerFunc = () => {
status.intentionalFail.set(true)
loadingCache.get(TEST_KEY)
}
val thread2Task: WorkerFunc = () => {
loadingCache.get(TEST_KEY)
}
testImpl(loadingCache, status, thread1Task, thread2Task)
}
testImpl(loadingCache, loader, thread1Task, thread2Task)
}

test("loading cache mix usage of default loader and provided loader") {
Expand Down Expand Up @@ -115,14 +136,14 @@ class NonFateSharingCacheSuite extends SparkFunSuite {

def testImpl(
cache: NonFateSharingCache[String, String],
loader: TestCacheLoader,
internalTestStatus: InternalTestStatus,
thread1Task: WorkerFunc,
thread2Task: WorkerFunc): Unit = {
val executor1 = ThreadUtils.newDaemonSingleThreadExecutor("test-executor1")
val executor2 = ThreadUtils.newDaemonSingleThreadExecutor("test-executor2")
val r1: Runnable = () => thread1Task()
val r2: Runnable = () => {
loader.startLoading.acquire() // wait until thread1 start loading
internalTestStatus.startLoading.acquire() // wait until thread1 start loading
THREAD2_HOLDER.set(Thread.currentThread())
thread2Task()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
import org.codehaus.janino.ClassBodyEvaluator
Expand Down Expand Up @@ -1581,23 +1580,22 @@ object CodeGenerator extends Logging {
* while other queries wait on the same code, so that those other queries don't get wrongly
* aborted. See [[NonFateSharingCache]] for more details.
*/
// Cache from source code to its compiled result, wrapped in NonFateSharingCache.
// The size bound is read from SQLConf (codegenCacheMaxEntries) when this val is initialized.
private val cache = NonFateSharingCache(CacheBuilder.newBuilder()
.maximumSize(SQLConf.get.codegenCacheMaxEntries)
.build(
new CacheLoader[CodeAndComment, (GeneratedClass, ByteCodeStats)]() {
// Compiles `code` with doCompile and records metrics about the compilation.
override def load(code: CodeAndComment): (GeneratedClass, ByteCodeStats) = {
val startTime = System.nanoTime()
val result = doCompile(code)
val endTime = System.nanoTime()
// Elapsed compilation time in nanoseconds, converted to milliseconds for reporting.
val duration = endTime - startTime
val timeMs: Double = duration.toDouble / NANOS_PER_MILLIS
CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length)
CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong)
logInfo(s"Code generated in $timeMs ms")
// The accumulated compile time is kept in nanoseconds.
_compileTime.add(duration)
result
}
}))
private val cache = {
  // Loader used by the cache: compiles the given code with doCompile, then records the
  // source size and the compilation time in the codegen metrics before returning the result.
  val compileAndRecord: CodeAndComment => (GeneratedClass, ByteCodeStats) = { code =>
    val begin = System.nanoTime()
    val compiled = doCompile(code)
    val elapsedNanos = System.nanoTime() - begin
    val timeMs: Double = elapsedNanos.toDouble / NANOS_PER_MILLIS
    CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length)
    CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong)
    logInfo(s"Code generated in $timeMs ms")
    // The accumulated compile time is tracked in nanoseconds.
    _compileTime.add(elapsedNanos)
    compiled
  }
  // The size bound is read from SQLConf once, when this val is initialized.
  NonFateSharingCache[CodeAndComment, (GeneratedClass, ByteCodeStats)](
    compileAndRecord, SQLConf.get.codegenCacheMaxEntries)
}

/**
* Name of Java primitive data type
Expand Down