diff --git a/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala b/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala index d9847313304a..21184d70b386 100644 --- a/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala +++ b/core/src/main/scala/org/apache/spark/util/NonFateSharingCache.scala @@ -19,8 +19,7 @@ package org.apache.spark.util import java.util.concurrent.Callable -import com.google.common.cache.Cache -import com.google.common.cache.LoadingCache +import com.google.common.cache.{Cache, CacheBuilder, CacheLoader, LoadingCache} /** * SPARK-43300: Guava cache fate-sharing behavior might lead to unexpected cascade failure: @@ -50,6 +49,25 @@ private[spark] object NonFateSharingCache { def apply[K, V](loadingCache: LoadingCache[K, V]): NonFateSharingLoadingCache[K, V] = new NonFateSharingLoadingCache(loadingCache) + + /** + * SPARK-44064: this `apply` overload builds a loading cache from a plain function, so that + * non-core modules no longer need to pass Guava cache types directly to the other + * `NonFateSharingCache#apply` overloads, which avoids the Maven test failures those modules + * hit when using the shaded core module. `loadingFunc` must not be null; a size limit is set + * on the cache only when `maximumSize` is positive. We should generalize this function when + * other requirements arise, or remove it when Maven testing is no longer supported. 
+ */ + def apply[K, V](loadingFunc: K => V, maximumSize: Long = 0L): NonFateSharingLoadingCache[K, V] = { + require(loadingFunc != null) + val builder = CacheBuilder.newBuilder().asInstanceOf[CacheBuilder[K, V]] + if (maximumSize > 0L) { + builder.maximumSize(maximumSize) + } + new NonFateSharingLoadingCache(builder.build[K, V](new CacheLoader[K, V] { + override def load(k: K): V = loadingFunc.apply(k) + })) + } } private[spark] class NonFateSharingCache[K, V](protected val cache: Cache[K, V]) { diff --git a/core/src/test/scala/org/apache/spark/util/NonFateSharingCacheSuite.scala b/core/src/test/scala/org/apache/spark/util/NonFateSharingCacheSuite.scala index b1780e81b2c1..df41b6b08daa 100644 --- a/core/src/test/scala/org/apache/spark/util/NonFateSharingCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NonFateSharingCacheSuite.scala @@ -31,7 +31,7 @@ object NonFateSharingCacheSuite { private val FAIL_MESSAGE = "loading failed" private val THREAD2_HOLDER = new AtomicReference[Thread](null) - class TestCacheLoader extends CacheLoader[String, String] { + trait InternalTestStatus { var intentionalFail: ThreadLocal[Boolean] = ThreadLocal.withInitial(() => false) var startLoading = new Semaphore(0) @@ -43,7 +43,10 @@ object NonFateSharingCacheSuite { } } } + } + + class TestCacheLoader extends CacheLoader[String, String] with InternalTestStatus { override def load(key: String): String = { startLoading.release() if (Thread.currentThread().getName.contains("test-executor1")) { @@ -53,6 +56,18 @@ object NonFateSharingCacheSuite { key } } + + class TestLoaderFunc extends InternalTestStatus { + def func: String => String = key => { + startLoading.release() + if (Thread.currentThread().getName.contains("test-executor1")) { + waitUntilThread2Waiting() + } + if (intentionalFail.get) throw new RuntimeException(FAIL_MESSAGE) + key + } + } + } /** @@ -66,16 +81,22 @@ class NonFateSharingCacheSuite extends SparkFunSuite { test("loading cache loading failure should not 
affect concurrent query on same key") { val loader = new TestCacheLoader - val loadingCache: NonFateSharingLoadingCache[String, String] = + val loadingCache0: NonFateSharingLoadingCache[String, String] = NonFateSharingCache(CacheBuilder.newBuilder.build(loader)) - val thread1Task: WorkerFunc = () => { - loader.intentionalFail.set(true) - loadingCache.get(TEST_KEY) - } - val thread2Task: WorkerFunc = () => { - loadingCache.get(TEST_KEY) + val loaderFunc = new TestLoaderFunc + val loadingCache1: NonFateSharingLoadingCache[String, String] = + NonFateSharingCache(loaderFunc.func) + Seq((loadingCache0, loader), (loadingCache1, loaderFunc)).foreach { + case (loadingCache, status) => + val thread1Task: WorkerFunc = () => { + status.intentionalFail.set(true) + loadingCache.get(TEST_KEY) + } + val thread2Task: WorkerFunc = () => { + loadingCache.get(TEST_KEY) + } + testImpl(loadingCache, status, thread1Task, thread2Task) } - testImpl(loadingCache, loader, thread1Task, thread2Task) } test("loading cache mix usage of default loader and provided loader") { @@ -115,14 +136,14 @@ class NonFateSharingCacheSuite extends SparkFunSuite { def testImpl( cache: NonFateSharingCache[String, String], - loader: TestCacheLoader, + internalTestStatus: InternalTestStatus, thread1Task: WorkerFunc, thread2Task: WorkerFunc): Unit = { val executor1 = ThreadUtils.newDaemonSingleThreadExecutor("test-executor1") val executor2 = ThreadUtils.newDaemonSingleThreadExecutor("test-executor2") val r1: Runnable = () => thread1Task() val r2: Runnable = () => { - loader.startLoading.acquire() // wait until thread1 start loading + internalTestStatus.startLoading.acquire() // wait until thread1 start loading THREAD2_HOLDER.set(Thread.currentThread()) thread2Task() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 15d68b0f9234..8d10f6cd2952 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -25,7 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} import org.codehaus.janino.ClassBodyEvaluator @@ -1581,23 +1580,22 @@ object CodeGenerator extends Logging { * while other queries wait on the same code, so that those other queries don't get wrongly * aborted. See [[NonFateSharingCache]] for more details. */ - private val cache = NonFateSharingCache(CacheBuilder.newBuilder() - .maximumSize(SQLConf.get.codegenCacheMaxEntries) - .build( - new CacheLoader[CodeAndComment, (GeneratedClass, ByteCodeStats)]() { - override def load(code: CodeAndComment): (GeneratedClass, ByteCodeStats) = { - val startTime = System.nanoTime() - val result = doCompile(code) - val endTime = System.nanoTime() - val duration = endTime - startTime - val timeMs: Double = duration.toDouble / NANOS_PER_MILLIS - CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length) - CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong) - logInfo(s"Code generated in $timeMs ms") - _compileTime.add(duration) - result - } - })) + private val cache = { + def loadFunc: CodeAndComment => (GeneratedClass, ByteCodeStats) = code => { + val startTime = System.nanoTime() + val result = doCompile(code) + val endTime = System.nanoTime() + val duration = endTime - startTime + val timeMs: Double = duration.toDouble / NANOS_PER_MILLIS + CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length) + CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong) + logInfo(s"Code generated in $timeMs 
ms") + _compileTime.add(duration) + result + } + NonFateSharingCache[CodeAndComment, (GeneratedClass, ByteCodeStats)]( + loadFunc, SQLConf.get.codegenCacheMaxEntries) + } /** * Name of Java primitive data type