Skip to content

Commit

Permalink
Prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
ndkoval committed Mar 28, 2024
1 parent b5dc8c9 commit 8c633dc
Showing 1 changed file with 252 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
/*
* Lincheck
*
* Copyright (C) 2019 - 2024 JetBrains s.r.o.
*
* This Source Code Form is subject to the terms of the
* Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed
* with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

package org.jetbrains.kotlinx.lincheck.transformation

import net.bytebuddy.agent.*
import org.jetbrains.kotlinx.lincheck.*
import org.jetbrains.kotlinx.lincheck.UnsafeHolder.UNSAFE
import org.objectweb.asm.*
import java.lang.instrument.*
import java.lang.invoke.MethodHandles
import java.lang.module.*
import java.lang.reflect.*
import java.lang.reflect.Modifier.*
import java.security.*
import java.util.*
import java.util.Collections.*
import java.util.concurrent.*
import java.util.jar.*

internal inline fun <R> withLincheckJavaAgent(transformationMode: TransformationMode, block: () -> R): R {
Class.forName("kotlin.text.StringsKt")
Class.forName("kotlin.jvm.internal.ArrayIteratorKt")
Class.forName("kotlin.jvm.internal.ArrayIterator")

LincheckClassFileTransformer.install(transformationMode)
return try {
block()
} finally {
LincheckClassFileTransformer.uninstall()
}
}

object LincheckClassFileTransformer : ClassFileTransformer {
private val transformedClassesModelChecking = ConcurrentHashMap<Any, ByteArray>()
private val transformedClassesStress = ConcurrentHashMap<Any, ByteArray>()
private val nonTransformedClasses = ConcurrentHashMap<Any, ByteArray>()

private val failedToRetransformClasses = newSetFromMap<Class<*>>(ConcurrentHashMap())

private val instrumentation = ByteBuddyAgent.install()

private var transformationMode = TransformationMode.STRESS

private val instrumentedClasses = HashSet<String>()

fun ensureClassAndAllSuperClassesAreTransformed(className: String) {
if (className in instrumentedClasses) return // already instrumented
ensureClassAndAllSuperClassesAreTransformed(Class.forName(className), newSetFromMap(IdentityHashMap()))
}

fun ensureClassAndAllSuperClassesAreTransformed(clazz: Class<*>) {
if (clazz.name in instrumentedClasses) return // already instrumented
ensureClassAndAllSuperClassesAreTransformed(clazz, newSetFromMap(IdentityHashMap()))
}


fun ensureObjectIsTransformed(testInstance: Any) {
ensureObjectIsTransformedImpl(testInstance, newSetFromMap(IdentityHashMap()))
}

private fun ensureObjectIsTransformedImpl(obj: Any, processedObjects: MutableSet<Any>) {
if (processedObjects.contains(obj)) return
processedObjects += obj

var clazz: Class<*> = obj.javaClass

ensureClassAndAllSuperClassesAreTransformed(clazz)

while (true) {
clazz.declaredFields
.filter { !it.type.isPrimitive }
.filter { !isStatic(it.modifiers) }
.mapNotNull { readFieldViaUnsafe(obj, it) }
.forEach { ensureObjectIsTransformedImpl(it, processedObjects) }
clazz = clazz.superclass ?: break
}
}

private fun ensureClassAndAllSuperClassesAreTransformed(clazz: Class<*>, processedObjects: MutableSet<Any>) {
if (instrumentation.isModifiableClass(clazz) && shouldTransform(clazz.name, transformationMode)) {
instrumentedClasses += clazz.name
println("Retransform! $clazz")
instrumentation.retransformClasses(clazz)
} else {
return
}
// Traverse static fields.
clazz.declaredFields
.filter { !it.type.isPrimitive }
.filter { isStatic(it.modifiers) }
.mapNotNull { readFieldViaUnsafe(null, it) }
.forEach { ensureObjectIsTransformedImpl(it, processedObjects) }
clazz.superclass?.let {
if (it.name in instrumentedClasses) return // already instrumented
ensureClassAndAllSuperClassesAreTransformed(it, processedObjects)
}
}

private fun readFieldViaUnsafe(obj: Any?, field: Field): Any? =
if (isStatic(field.modifiers)) {
val base = UNSAFE.staticFieldBase(field)
val offset = UNSAFE.staticFieldOffset(field)
UNSAFE.getObject(base, offset)
} else {
val offset = UNSAFE.objectFieldOffset(field)
UNSAFE.getObject(obj, offset)
}

internal fun install(transformationMode: TransformationMode) {
this.transformationMode = transformationMode
TransformationInjectionsInitializer.initialize(instrumentation)
instrumentation.addTransformer(this@LincheckClassFileTransformer, true)
if (transformationMode == TransformationMode.STRESS) {
instrumentation.retransformClasses(*getLoadedClassesToInstrument().toTypedArray())
}
}

private fun getLoadedClassesToInstrument() = instrumentation.allLoadedClasses
.filter(instrumentation::isModifiableClass)
.filter { shouldTransform(it.name, transformationMode) }
.filter { it !in failedToRetransformClasses }

internal fun uninstall() {
if (System.getProperty("INTERNAL_TESTS") == "true") {
val transformedClassesNames = transformedClassCache.keys
.map { if (it is Pair<*, *>) it.second.toString() else it.toString() }
.map { it.replace("/", ".").replace("\$", "\\\$") }
.toHashSet()

val loadedClasses = if (transformationMode == TransformationMode.STRESS) {
getLoadedClassesToInstrument()
} else {
getLoadedClassesToInstrument().filter { it.name in instrumentedClasses }
}
}
instrumentation.removeTransformer(this)
val classDefinitions = getLoadedClassesToInstrument().filter {
if (transformationMode == TransformationMode.MODEL_CHECKING) {
it.name in instrumentedClasses
} else {
true
}
}.mapNotNull { clazz ->
val bytes = nonTransformedClasses[classKey(clazz.classLoader, clazz.name)]
bytes?.let { ClassDefinition(clazz, it) }
}
instrumentation.redefineClasses(*classDefinitions.toTypedArray())
instrumentedClasses.clear()
}

override fun transform(
loader: ClassLoader?,
className: String,
classBeingRedefined: Class<*>?,
protectionDomain: ProtectionDomain?,
classBytes: ByteArray
): ByteArray? {
runInIgnoredSection {
if (transformationMode == TransformationMode.STRESS) {
if (!shouldTransform(className.canonicalClassName, transformationMode)) return null
} else {
if (className.canonicalClassName !in instrumentedClasses) return null
}
return transformImpl(loader, className, classBytes)
}
}

private fun transformImpl(loader: ClassLoader?, className: String, classBytes: ByteArray): ByteArray =
transformedClassCache.computeIfAbsent(classKey(loader, className)) {
nonTransformedClasses[classKey(loader, className)] = classBytes
val reader = ClassReader(classBytes)
val writer = SafeClassWriter(reader, loader, ClassWriter.COMPUTE_FRAMES)
try {
reader.accept(LincheckClassVisitor(transformationMode, writer), ClassReader.SKIP_FRAMES)
writer.toByteArray()
} catch (e: Throwable) {
System.err.println("Unable to transform $className")
e.printStackTrace()
classBytes
}
}

private val transformedClassCache
get() = when (transformationMode) {
TransformationMode.STRESS -> transformedClassesStress
TransformationMode.MODEL_CHECKING -> transformedClassesModelChecking
}

@Suppress("SpellCheckingInspection")
private fun shouldTransform(className: String, transformationMode: TransformationMode): Boolean {
if (transformationMode == TransformationMode.STRESS) {
if (className.startsWith("java.") || className.startsWith("kotlin.")) return false
}

if (className.startsWith("net.bytebuddy.")) return false
if (className.startsWith("io.mockk.")) return false
if (className.startsWith("it.unimi.dsi.fastutil.")) return false
if (className.startsWith("kotlinx.atomicfu.")) return false
if (className.startsWith("worker.org.gradle.")) return false
if (className.startsWith("org.objectweb.asm.")) return false
if (className.startsWith("org.junit.")) return false
if (className.startsWith("junit.framework.")) return false
if (className.startsWith("com.sun.")) return false
if (className.startsWith("java.util.")) {
if (className.startsWith("java.util.concurrent.atomic.") && className.contains("Atomic")) return false
if (className.endsWith("Exception")) return false
return true
}

if (className.startsWith("org.jetbrains.kotlinx.lincheck.ExecutionClassLoader")) return false

if (className.startsWith("sun.") ||
className.startsWith("java.") ||
className.startsWith("javax.") ||
className.startsWith("jdk.") ||
className.startsWith("org.jetbrains.kotlinx.lincheck.") ||
className.startsWith("sun.nio.ch.lincheck.") ||
className == "kotlinx.coroutines.DebugKt" ||
className == "kotlin.coroutines.jvm.internal.DebugProbesKt"
) return false

return true
}
}

internal object TransformationInjectionsInitializer {
private var initialized = false

fun initialize(instrumentation: Instrumentation) {
if (initialized) return

val bootstrapJarFile = JarFile(ClassLoader.getSystemResource("bootstrap.jar").file)
instrumentation.appendToBootstrapClassLoaderSearch(bootstrapJarFile)

initialized = true
}
}

private fun classKey(loader: ClassLoader?, className: String) =
if (loader == null) {
className
} else {
className to loader
}

0 comments on commit 8c633dc

Please sign in to comment.