From 94d094129dafcd2742a6c199a1a62789be7b0105 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Mon, 13 Apr 2026 17:50:34 -0700 Subject: [PATCH] Handle plugin-private stack types in callsite binder --- .../java/io/trino/sql/gen/BytecodeUtils.java | 2 +- .../trino/sql/gen/CachedInstanceBinder.java | 2 +- .../java/io/trino/sql/gen/CallSiteBinder.java | 52 ++++- .../columnar/CallColumnarFilterGenerator.java | 4 +- .../sql/gen/TestPageFunctionCompiler.java | 198 ++++++++++++++++++ 5 files changed, 253 insertions(+), 5 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java index 5b2a3dd19d50..36b7b1731c7b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java @@ -291,7 +291,7 @@ private static BytecodeNode generateFullInvocation( Class type = methodType.parameterArray()[currentParameterIndex]; stackTypes.add(type); if (instance.isPresent() && !instanceIsBound) { - checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter"); + checkState(type.equals(binder.getAccessibleType(implementation.getInstanceFactory().get().type().returnType())), "Mismatched type for instance parameter"); block.append(instance.get()); instanceIsBound = true; } diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/CachedInstanceBinder.java b/core/trino-main/src/main/java/io/trino/sql/gen/CachedInstanceBinder.java index fc33d4060c28..e29baf80d9c3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/CachedInstanceBinder.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/CachedInstanceBinder.java @@ -49,7 +49,7 @@ public CallSiteBinder getCallSiteBinder() public FieldDefinition getCachedInstance(MethodHandle methodHandle) { - FieldDefinition field = classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance" + nextId, methodHandle.type().returnType()); + FieldDefinition field = classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance" + nextId, callSiteBinder.getAccessibleType(methodHandle.type().returnType())); initializers.put(field, methodHandle); nextId++; return field; diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/CallSiteBinder.java b/core/trino-main/src/main/java/io/trino/sql/gen/CallSiteBinder.java index aedc65c45f04..efe1dfa87e5d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/CallSiteBinder.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/CallSiteBinder.java @@ -17,21 +17,41 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; import java.util.HashMap; import java.util.Map; import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; public final class CallSiteBinder { private int nextId; + private final ClassLoader classLoader; private final Map bindings = new HashMap<>(); + public CallSiteBinder() + { + this(CallSiteBinder.class.getClassLoader()); + } + + public CallSiteBinder(ClassLoader classLoader) + { + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + public Binding bind(MethodHandle method) { + // Bound handles can come from plugin class loaders. Hide reference + // types the generated-code loader cannot resolve to the same class. + MethodType type = getAccessibleType(method.type()); + if (!method.type().equals(type)) { + method = method.asType(type); + } + long bindingId = nextId++; - Binding binding = new Binding(bindingId, method.type()); + Binding binding = new Binding(bindingId, type); bindings.put(bindingId, method); return binding; @@ -47,6 +67,36 @@ public Map getBindings() return ImmutableMap.copyOf(bindings); } + public MethodType getAccessibleType(MethodType type) + { + MethodType accessibleType = type.changeReturnType(getAccessibleType(type.returnType())); + for (int i = 0; i < type.parameterCount(); i++) { + accessibleType = accessibleType.changeParameterType(i, getAccessibleType(type.parameterType(i))); + } + return accessibleType; + } + + public Class getAccessibleType(Class type) + { + if (isAccessible(type)) { + return type; + } + return Object.class; + } + + private boolean isAccessible(Class type) + { + if (type.isPrimitive()) { + return true; + } + try { + return Class.forName(type.getName(), false, classLoader) == type; + } + catch (ClassNotFoundException | LinkageError _) { + return false; + } + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java index 8013ca947550..90053c968bb2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java @@ -309,7 +309,7 @@ private static BytecodeBlock generateFullInvocation( while (currentParameterIndex < methodType.parameterArray().length) { Class type = methodType.parameterArray()[currentParameterIndex]; if (instance.isPresent() && !instanceIsBound) { - checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter"); + checkState(type.equals(binder.getAccessibleType(implementation.getInstanceFactory().get().type().returnType())), "Mismatched type for instance parameter"); block.append(instance.get()); instanceIsBound = true; } @@ -429,7 +429,7 @@ public CachedInstanceBinder(ClassDefinition classDefinition, CallSiteBinder call public FieldDefinition getCachedInstance(MethodHandle methodHandle) { if (field.isEmpty()) { - field = Optional.of(classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance", methodHandle.type().returnType())); + field = Optional.of(classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance", callSiteBinder.getAccessibleType(methodHandle.type().returnType()))); method = Optional.of(methodHandle); } return field.get(); diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java index 53fd1ec1e208..e449ca04615b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java @@ -14,27 +14,72 @@ package io.trino.sql.gen; import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.DynamicClassLoader; +import io.airlift.bytecode.FieldDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.slice.Slice; +import io.trino.metadata.InternalFunctionBundle; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.SqlScalarFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.project.PageProjection; import io.trino.operator.project.SelectedPositions; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.SourcePage; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.trino.spi.function.Signature; +import io.trino.spi.type.AbstractVariableWidthType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignature; +import io.trino.sql.PlannerContext; import io.trino.sql.relational.CallExpression; +import io.trino.transaction.TransactionManager; import org.junit.jupiter.api.Test; +import java.lang.invoke.MethodHandle; +import java.lang.reflect.Field; +import java.util.List; import java.util.Optional; import java.util.function.Supplier; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.STATIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.slice.Slices.allocate; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.ADD; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.sql.relational.Expressions.call; import static io.trino.sql.relational.Expressions.constant; import static io.trino.sql.relational.Expressions.field; import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; +import static io.trino.util.Reflection.constructorMethodHandle; +import static io.trino.util.Reflection.field; +import static io.trino.util.Reflection.methodHandle; +import static java.lang.invoke.MethodHandles.insertArguments; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; public class TestPageFunctionCompiler @@ -69,6 +114,36 @@ public void testFailureDoesNotCorruptFutureResults() assertThat(goodPage.getPositionCount()).isEqualTo(goodResult.getPositionCount()); } + @Test + public void testProjectionWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + + TransactionManager transactionManager = createTestTransactionManager(); + PlannerContext plannerContext = plannerContextBuilder() + .withTransactionManager(transactionManager) + .addType(hiddenType) + .addFunctions(InternalFunctionBundle.builder() + .function(new HiddenFunction("test_hidden_constructor", hiddenType, hiddenFunctions.constructor(), ImmutableList.of())) + .function(new HiddenFunction("test_hidden_identity", hiddenType, hiddenFunctions.identity(), ImmutableList.of(hiddenType))) + .build()) + .build(); + TestingFunctionResolution functionResolution = new TestingFunctionResolution(transactionManager, plannerContext); + + ResolvedFunction constructor = functionResolution.resolveFunction("test_hidden_constructor", fromTypes()); + ResolvedFunction identity = functionResolution.resolveFunction("test_hidden_identity", fromTypes(hiddenType)); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection(call(identity, call(constructor)), Optional.empty()) + .get(); + + Page page = createLongBlockPage(0, 1); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(result.getPositionCount()).isEqualTo(page.getPositionCount()); + assertThat(hiddenType.getObjectValue(result, 0)).isEqualTo(42); + assertThat(hiddenType.getObjectValue(result, 1)).isEqualTo(42); + } + @Test public void testCache() { @@ -98,4 +173,127 @@ private static Page createLongBlockPage(long... values) } return new Page(builder.build()); } + + private static HiddenFunctions createHiddenFunctions() + { + ClassDefinition classDefinition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("HiddenValue"), + type(Object.class)); + FieldDefinition valueField = classDefinition.declareField(a(PUBLIC), "value", int.class); + + Parameter constructorValue = arg("value", int.class); + MethodDefinition constructor = classDefinition.declareConstructor(a(PUBLIC), constructorValue); + constructor.getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class) + .append(constructor.getThis().setField(valueField, constructorValue)) + .ret(); + + Parameter identityValue = arg("value", classDefinition.getType()); + MethodDefinition identity = classDefinition.declareMethod(a(PUBLIC, STATIC), "identity", classDefinition.getType(), identityValue); + identity.getBody() + .append(identityValue.ret()); + + Class hiddenClass = defineClass(classDefinition, Object.class, new DynamicClassLoader(TestPageFunctionCompiler.class.getClassLoader())); + return new HiddenFunctions( + new HiddenType(hiddenClass), + insertArguments(constructorMethodHandle(hiddenClass, int.class), 0, 42), + methodHandle(hiddenClass, "identity", hiddenClass)); + } + + private record HiddenFunctions(Type type, MethodHandle constructor, MethodHandle identity) {} + + private static final class HiddenFunction + extends SqlScalarFunction + { + private final MethodHandle methodHandle; + private final List argumentConventions; + + private HiddenFunction(String name, Type returnType, MethodHandle methodHandle, List argumentTypes) + { + super(functionMetadata(name, returnType, argumentTypes)); + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); + this.argumentConventions = nCopies(argumentTypes.size(), NEVER_NULL); + } + + @Override + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) + { + return new ChoicesSpecializedSqlScalarFunction(boundSignature, FAIL_ON_NULL, argumentConventions, methodHandle); + } + + private static FunctionMetadata functionMetadata(String name, Type returnType, List argumentTypes) + { + Signature.Builder signature = Signature.builder() + .returnType(returnType); + argumentTypes.forEach(signature::argumentType); + return FunctionMetadata.scalarBuilder(name) + .signature(signature.build()) + .hidden() + .description("test hidden type function") + .build(); + } + } + + private static final class HiddenType + extends AbstractVariableWidthType + { + private static final TypeSignature TYPE_SIGNATURE = new TypeSignature("test_hidden"); + + private final Field valueField; + + private HiddenType(Class javaType) + { + super(TYPE_SIGNATURE, javaType); + valueField = field(javaType, "value"); + } + + @Override + public String getDisplayName() + { + return TYPE_SIGNATURE.toString(); + } + + @Override + public Object getObjectValue(Block block, int position) + { + if (block.isNull(position)) { + return null; + } + return getSlice(block, position).getInt(0); + } + + @Override + public Slice getSlice(Block block, int position) + { + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + return valueBlock.getSlice(block.getUnderlyingValuePosition(position)); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value) + { + writeSlice(blockBuilder, value, 0, value.length()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) + { + ((VariableWidthBlockBuilder) blockBuilder).writeEntry(value, offset, length); + } + + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + try { + Slice slice = allocate(Integer.BYTES); + slice.setInt(0, valueField.getInt(value)); + writeSlice(blockBuilder, slice); + } + catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + } }