diff --git a/src/main/java/io/airlift/bytecode/DumpBytecodeVisitor.java b/src/main/java/io/airlift/bytecode/DumpBytecodeVisitor.java index 2cf67bd..6a3bc9d 100644 --- a/src/main/java/io/airlift/bytecode/DumpBytecodeVisitor.java +++ b/src/main/java/io/airlift/bytecode/DumpBytecodeVisitor.java @@ -21,6 +21,7 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.control.SwitchStatement; import io.airlift.bytecode.control.TryCatch; +import io.airlift.bytecode.control.TryCatch.CatchBlock; import io.airlift.bytecode.control.WhileLoop; import io.airlift.bytecode.debug.LineNumberNode; import io.airlift.bytecode.expression.BytecodeExpression; @@ -54,6 +55,7 @@ import static io.airlift.bytecode.Access.INTERFACE; import static io.airlift.bytecode.ParameterizedType.type; import static java.lang.String.format; +import static java.util.stream.Collectors.joining; public class DumpBytecodeVisitor extends BytecodeVisitor @@ -311,11 +313,15 @@ public Void visitTryCatch(BytecodeNode parent, TryCatch tryCatch) indentLevel--; printLine("}"); - printLine("catch (%s) {", tryCatch.getExceptionName()); - indentLevel++; - tryCatch.getCatchNode().accept(tryCatch, this); - indentLevel--; - printLine("}"); + for (CatchBlock catchBlock : tryCatch.getCatchBlocks()) { + printLine("catch (%s) {", catchBlock.getExceptionTypes().stream() + .map(ParameterizedType::getJavaClassName) + .collect(joining(" | "))); + indentLevel++; + catchBlock.getHandler().accept(tryCatch, this); + indentLevel--; + printLine("}"); + } return null; } diff --git a/src/main/java/io/airlift/bytecode/FastMethodHandleProxies.java b/src/main/java/io/airlift/bytecode/FastMethodHandleProxies.java new file mode 100644 index 0000000..14d5a39 --- /dev/null +++ b/src/main/java/io/airlift/bytecode/FastMethodHandleProxies.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.bytecode; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.bytecode.control.TryCatch; +import io.airlift.bytecode.control.TryCatch.CatchBlock; + +import java.lang.invoke.CallSite; +import java.lang.invoke.ConstantCallSite; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandleProxies; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.UndeclaredThrowableException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.SYNTHETIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.ClassGenerator.classGenerator; +import static io.airlift.bytecode.FastMethodHandleProxies.Bootstrap.BOOTSTRAP_METHOD; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.ParameterizedType.typeFromJavaClassName; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; +import static java.lang.invoke.MethodType.methodType; +import static java.util.Arrays.stream; + +public final class FastMethodHandleProxies +{ + private static final AtomicLong CLASS_ID = new AtomicLong(); + + private FastMethodHandleProxies() {} + + /** + * Faster version of {@link MethodHandleProxies#asInterfaceInstance(Class, MethodHandle)}. + */ + public static T asInterfaceInstance(Class type, MethodHandle target) + { + checkArgument(type.isInterface() && Modifier.isPublic(type.getModifiers()), "not a public interface: %s", type.getName()); + + ClassDefinition classDefinition = new ClassDefinition( + a(PUBLIC, FINAL, SYNTHETIC), + typeFromJavaClassName("$gen." + type.getName() + "_" + CLASS_ID.incrementAndGet()), + type(Object.class), + type(type)); + + classDefinition.declareDefaultConstructor(a(PUBLIC)); + + Method method = getSingleAbstractMethod(type); + Class[] parameterTypes = method.getParameterTypes(); + MethodHandle adaptedTarget = target.asType(methodType(method.getReturnType(), parameterTypes)); + + List parameters = new ArrayList<>(); + for (int i = 0; i < parameterTypes.length; i++) { + parameters.add(arg("arg" + i, parameterTypes[i])); + } + + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC), + method.getName(), + type(method.getReturnType()), + parameters); + + BytecodeNode invocation = invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(), + method.getName(), + method.getReturnType(), + parameters) + .ret(); + + ImmutableList.Builder exceptionTypes = ImmutableList.builder(); + exceptionTypes.add(type(RuntimeException.class), type(Error.class)); + for (Class exceptionType : method.getExceptionTypes()) { + methodDefinition.addException(exceptionType.asSubclass(Throwable.class)); + exceptionTypes.add(type(exceptionType)); + } + + BytecodeNode throwUndeclared = new BytecodeBlock() + .newObject(UndeclaredThrowableException.class) + .append(OpCode.DUP_X1) + .swap() + .invokeConstructor(UndeclaredThrowableException.class, Throwable.class) + .throwObject(); + + invocation = new TryCatch(invocation, ImmutableList.of( + new CatchBlock(new BytecodeBlock().throwObject(), exceptionTypes.build()), + new CatchBlock(throwUndeclared, ImmutableList.of()))); + + methodDefinition.getBody().append(invocation); + + // note this will not work if interface class is not visible from this class loader, + // but we must use this class loader to ensure the bootstrap method is visible + ClassLoader targetClassLoader = FastMethodHandleProxies.class.getClassLoader(); + DynamicClassLoader dynamicClassLoader = new DynamicClassLoader(targetClassLoader, ImmutableMap.of(0L, adaptedTarget)); + Class newClass = classGenerator(dynamicClassLoader).defineClass(classDefinition, type); + try { + return newClass.getDeclaredConstructor().newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private static Method getSingleAbstractMethod(Class type) + { + return stream(type.getMethods()) + .filter(method -> Modifier.isAbstract(method.getModifiers())) + .filter(method -> Modifier.isPublic(method.getModifiers())) + .filter(method -> method.getDeclaringClass() != Object.class) + .filter(FastMethodHandleProxies::notJavaObjectMethod) + .collect(onlyElement()); + } + + private static boolean notJavaObjectMethod(Method method) + { + return notMethodMatches(method, "toString", String.class) && + notMethodMatches(method, "hashCode", int.class) && + notMethodMatches(method, "equals", boolean.class, Object.class); + } + + private static boolean notMethodMatches(Method method, String name, Class returnType, Class... parameterTypes) + { + return method.getParameterCount() != parameterTypes.length || + method.getReturnType() != returnType || + !name.equals(method.getName()) || + !Arrays.equals(method.getParameterTypes(), parameterTypes); + } + + public static final class Bootstrap + { + public static final Method BOOTSTRAP_METHOD; + + static { + try { + BOOTSTRAP_METHOD = Bootstrap.class.getMethod("bootstrap", MethodHandles.Lookup.class, String.class, MethodType.class); + } + catch (NoSuchMethodException e) { + throw new AssertionError(e); + } + } + + private Bootstrap() {} + + @SuppressWarnings("unused") + public static CallSite bootstrap(MethodHandles.Lookup callerLookup, String name, MethodType type) + { + DynamicClassLoader classLoader = (DynamicClassLoader) callerLookup.lookupClass().getClassLoader(); + return new ConstantCallSite(classLoader.getCallSiteBindings().get(0L)); + } + } +} diff --git a/src/main/java/io/airlift/bytecode/control/TryCatch.java b/src/main/java/io/airlift/bytecode/control/TryCatch.java index 1548583..c0c5545 100644 --- a/src/main/java/io/airlift/bytecode/control/TryCatch.java +++ b/src/main/java/io/airlift/bytecode/control/TryCatch.java @@ -22,8 +22,11 @@ import io.airlift.bytecode.instruction.LabelNode; import org.objectweb.asm.MethodVisitor; +import java.util.ArrayList; import java.util.List; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class TryCatch @@ -31,20 +34,18 @@ public class TryCatch { private final String comment; private final BytecodeNode tryNode; - private final BytecodeNode catchNode; - private final String exceptionName; + private final List catchBlocks; - public TryCatch(BytecodeNode tryNode, BytecodeNode catchNode, ParameterizedType exceptionType) + public TryCatch(BytecodeNode tryNode, List catchBlocks) { - this(null, tryNode, catchNode, exceptionType); + this(null, tryNode, catchBlocks); } - public TryCatch(String comment, BytecodeNode tryNode, BytecodeNode catchNode, ParameterizedType exceptionType) + public TryCatch(String comment, BytecodeNode tryNode, List catchBlocks) { this.comment = comment; this.tryNode = requireNonNull(tryNode, "tryNode is null"); - this.catchNode = requireNonNull(catchNode, "catchNode is null"); - this.exceptionName = (exceptionType != null) ? exceptionType.getClassName() : null; + this.catchBlocks = ImmutableList.copyOf(requireNonNull(catchBlocks, "catchBlocks is null")); } @Override @@ -58,14 +59,9 @@ public BytecodeNode getTryNode() return tryNode; } - public BytecodeNode getCatchNode() + public List getCatchBlocks() { - return catchNode; - } - - public String getExceptionName() - { - return exceptionName; + return catchBlocks; } @Override @@ -73,7 +69,7 @@ public void accept(MethodVisitor visitor, MethodGenerationContext generationCont { LabelNode tryStart = new LabelNode("tryStart"); LabelNode tryEnd = new LabelNode("tryEnd"); - LabelNode handler = new LabelNode("handler"); + List handlers = new ArrayList<>(); LabelNode done = new LabelNode("done"); BytecodeBlock block = new BytecodeBlock(); @@ -84,21 +80,48 @@ public void accept(MethodVisitor visitor, MethodGenerationContext generationCont .visitLabel(tryEnd) .gotoLabel(done); - // handler block - block.visitLabel(handler) - .append(catchNode); + // catch blocks + for (int i = 0; i < catchBlocks.size(); i++) { + BytecodeNode handlerBlock = catchBlocks.get(i).getHandler(); + LabelNode handler = new LabelNode("handler" + i); + handlers.add(handler); + block.visitLabel(handler) + .append(handlerBlock); + } // all done block.visitLabel(done); block.accept(visitor, generationContext); - visitor.visitTryCatchBlock(tryStart.getLabel(), tryEnd.getLabel(), handler.getLabel(), exceptionName); + + // exception table + for (int i = 0; i < catchBlocks.size(); i++) { + LabelNode handler = handlers.get(i); + List exceptionTypes = catchBlocks.get(i).getExceptionTypes(); + for (ParameterizedType type : exceptionTypes) { + visitor.visitTryCatchBlock( + tryStart.getLabel(), + tryEnd.getLabel(), + handler.getLabel(), + type.getClassName()); + } + if (exceptionTypes.isEmpty()) { + visitor.visitTryCatchBlock( + tryStart.getLabel(), + tryEnd.getLabel(), + handler.getLabel(), + null); + } + } } @Override public List getChildNodes() { - return ImmutableList.of(tryNode, catchNode); + return Stream.concat( + Stream.of(tryNode), + catchBlocks.stream().map(CatchBlock::getHandler)) + .collect(toImmutableList()); } @Override @@ -106,4 +129,26 @@ public T accept(BytecodeNode parent, BytecodeVisitor visitor) { return visitor.visitTryCatch(parent, this); } + + public static class CatchBlock + { + private final BytecodeNode handler; + private final List exceptionTypes; + + public CatchBlock(BytecodeNode handler, List exceptionTypes) + { + this.handler = requireNonNull(handler, "handler is null"); + this.exceptionTypes = ImmutableList.copyOf(requireNonNull(exceptionTypes, "exceptionTypes is null")); + } + + public BytecodeNode getHandler() + { + return handler; + } + + public List getExceptionTypes() + { + return exceptionTypes; + } + } } diff --git a/src/test/java/io/airlift/bytecode/TestFastMethodHandleProxies.java b/src/test/java/io/airlift/bytecode/TestFastMethodHandleProxies.java new file mode 100644 index 0000000..5a42bdb --- /dev/null +++ b/src/test/java/io/airlift/bytecode/TestFastMethodHandleProxies.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.airlift.bytecode; + +import com.google.common.base.VerifyException; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandleProxies; +import java.lang.invoke.MutableCallSite; +import java.lang.reflect.UndeclaredThrowableException; +import java.util.function.Consumer; +import java.util.function.IntSupplier; +import java.util.function.LongFunction; +import java.util.function.LongUnaryOperator; + +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; + +public class TestFastMethodHandleProxies +{ + @Test + public void testBasic() + throws ReflectiveOperationException + { + assertInterface( + LongUnaryOperator.class, + lookup().findStatic(getClass(), "increment", methodType(long.class, long.class)), + addOne -> assertEquals(addOne.applyAsLong(1), 2L)); + } + + private static long increment(long x) + { + return x + 1; + } + + @Test + public void testGeneric() + throws ReflectiveOperationException + { + assertInterface( + LongFunction.class, + lookup().findStatic(getClass(), "incrementAndPrint", methodType(String.class, long.class)), + print -> assertEquals(print.apply(1), "2")); + } + + private static String incrementAndPrint(long x) + { + return String.valueOf(x + 1); + } + + @Test + public void testObjectAndDefaultMethods() + throws ReflectiveOperationException + { + assertInterface( + StringLength.class, + lookup().findStatic(getClass(), "stringLength", methodType(int.class, String.class)), + length -> { + assertEquals(length.length("abc"), 3); + assertEquals(length.theAnswer(), 42); + }); + } + + private static int stringLength(String s) + { + return s.length(); + } + + public interface StringLength + { + int length(String s); + + default int theAnswer() + { + return 42; + } + + @Override + String toString(); + } + + @Test + public void testUncheckedException() + throws ReflectiveOperationException + { + assertInterface( + Runnable.class, + lookup().findStatic(getClass(), "throwUncheckedException", methodType(void.class)), + runnable -> assertThatThrownBy(runnable::run) + .isInstanceOf(VerifyException.class)); + } + + private static void throwUncheckedException() + { + throw new VerifyException("unchecked"); + } + + @Test + public void testCheckedException() + throws ReflectiveOperationException + { + assertInterface( + Runnable.class, + lookup().findStatic(getClass(), "throwCheckedException", methodType(void.class)), + runnable -> assertThatThrownBy(runnable::run) + .isInstanceOf(UndeclaredThrowableException.class) + .hasCauseInstanceOf(IOException.class)); + } + + private static void throwCheckedException() + throws IOException + { + throw new IOException("checked"); + } + + @Test + public void testMutableCallSite() + throws ReflectiveOperationException + { + MethodHandle one = lookup().findStatic(getClass(), "one", methodType(int.class)); + MethodHandle two = lookup().findStatic(getClass(), "two", methodType(int.class)); + + MutableCallSite callSite = new MutableCallSite(methodType(int.class)); + assertInterface( + IntSupplier.class, + callSite.dynamicInvoker(), + supplier -> { + callSite.setTarget(one); + assertEquals(supplier.getAsInt(), 1); + callSite.setTarget(two); + assertEquals(supplier.getAsInt(), 2); + }); + } + + private static int one() + { + return 1; + } + + private static int two() + { + return 2; + } + + private static void assertInterface(Class interfaceType, MethodHandle target, Consumer consumer) + { + consumer.accept(MethodHandleProxies.asInterfaceInstance(interfaceType, target)); + consumer.accept(FastMethodHandleProxies.asInterfaceInstance(interfaceType, target)); + } +}