Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ private static BytecodeNode generateFullInvocation(
Class<?> type = methodType.parameterArray()[currentParameterIndex];
stackTypes.add(type);
if (instance.isPresent() && !instanceIsBound) {
checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
checkState(type.equals(binder.getAccessibleType(implementation.getInstanceFactory().get().type().returnType())), "Mismatched type for instance parameter");
block.append(instance.get());
instanceIsBound = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public CallSiteBinder getCallSiteBinder()

public FieldDefinition getCachedInstance(MethodHandle methodHandle)
{
FieldDefinition field = classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance" + nextId, methodHandle.type().returnType());
FieldDefinition field = classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance" + nextId, callSiteBinder.getAccessibleType(methodHandle.type().returnType()));
initializers.put(field, methodHandle);
nextId++;
return field;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,41 @@

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.HashMap;
import java.util.Map;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public final class CallSiteBinder
{
private int nextId;

private final ClassLoader classLoader;
private final Map<Long, MethodHandle> bindings = new HashMap<>();

public CallSiteBinder()
{
this(CallSiteBinder.class.getClassLoader());
}

public CallSiteBinder(ClassLoader classLoader)
{
this.classLoader = requireNonNull(classLoader, "classLoader is null");
}

public Binding bind(MethodHandle method)
{
// Bound handles can come from plugin class loaders. Hide reference
// types the generated-code loader cannot resolve to the same class.
MethodType type = getAccessibleType(method.type());
if (!method.type().equals(type)) {
method = method.asType(type);
}

long bindingId = nextId++;
Binding binding = new Binding(bindingId, method.type());
Binding binding = new Binding(bindingId, type);

bindings.put(bindingId, method);
return binding;
Expand All @@ -47,6 +67,36 @@ public Map<Long, MethodHandle> getBindings()
return ImmutableMap.copyOf(bindings);
}

public MethodType getAccessibleType(MethodType type)
{
MethodType accessibleType = type.changeReturnType(getAccessibleType(type.returnType()));
for (int i = 0; i < type.parameterCount(); i++) {
accessibleType = accessibleType.changeParameterType(i, getAccessibleType(type.parameterType(i)));
}
return accessibleType;
}

public Class<?> getAccessibleType(Class<?> type)
{
if (isAccessible(type)) {
return type;
}
return Object.class;
}

private boolean isAccessible(Class<?> type)
{
if (type.isPrimitive()) {
return true;
}
try {
return Class.forName(type.getName(), false, classLoader) == type;
}
catch (ClassNotFoundException | LinkageError _) {
return false;
}
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ private static BytecodeBlock generateFullInvocation(
while (currentParameterIndex < methodType.parameterArray().length) {
Class<?> type = methodType.parameterArray()[currentParameterIndex];
if (instance.isPresent() && !instanceIsBound) {
checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
checkState(type.equals(binder.getAccessibleType(implementation.getInstanceFactory().get().type().returnType())), "Mismatched type for instance parameter");
block.append(instance.get());
instanceIsBound = true;
}
Expand Down Expand Up @@ -429,7 +429,7 @@ public CachedInstanceBinder(ClassDefinition classDefinition, CallSiteBinder call
public FieldDefinition getCachedInstance(MethodHandle methodHandle)
{
if (field.isEmpty()) {
field = Optional.of(classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance", methodHandle.type().returnType()));
field = Optional.of(classDefinition.declareField(a(PRIVATE, FINAL), "__cachedInstance", callSiteBinder.getAccessibleType(methodHandle.type().returnType())));
method = Optional.of(methodHandle);
}
return field.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,72 @@
package io.trino.sql.gen;

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.FieldDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.slice.Slice;
import io.trino.metadata.InternalFunctionBundle;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.SqlScalarFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.project.PageProjection;
import io.trino.operator.project.SelectedPositions;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlock;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.AbstractVariableWidthType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.PlannerContext;
import io.trino.sql.relational.CallExpression;
import io.trino.transaction.TransactionManager;
import org.junit.jupiter.api.Test;

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;

import static io.airlift.bytecode.Access.FINAL;
import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.STATIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.slice.Slices.allocate;
import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.ADD;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder;
import static io.trino.sql.relational.Expressions.call;
import static io.trino.sql.relational.Expressions.constant;
import static io.trino.sql.relational.Expressions.field;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy;
import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static io.trino.util.CompilerUtils.defineClass;
import static io.trino.util.CompilerUtils.makeClassName;
import static io.trino.util.Reflection.constructorMethodHandle;
import static io.trino.util.Reflection.field;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.insertArguments;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

public class TestPageFunctionCompiler
Expand Down Expand Up @@ -69,6 +114,36 @@ public void testFailureDoesNotCorruptFutureResults()
assertThat(goodPage.getPositionCount()).isEqualTo(goodResult.getPositionCount());
}

@Test
public void testProjectionWithPrivateJavaType()
{
HiddenFunctions hiddenFunctions = createHiddenFunctions();
Type hiddenType = hiddenFunctions.type();

TransactionManager transactionManager = createTestTransactionManager();
PlannerContext plannerContext = plannerContextBuilder()
.withTransactionManager(transactionManager)
.addType(hiddenType)
.addFunctions(InternalFunctionBundle.builder()
.function(new HiddenFunction("test_hidden_constructor", hiddenType, hiddenFunctions.constructor(), ImmutableList.of()))
.function(new HiddenFunction("test_hidden_identity", hiddenType, hiddenFunctions.identity(), ImmutableList.of(hiddenType)))
.build())
.build();
TestingFunctionResolution functionResolution = new TestingFunctionResolution(transactionManager, plannerContext);

ResolvedFunction constructor = functionResolution.resolveFunction("test_hidden_constructor", fromTypes());
ResolvedFunction identity = functionResolution.resolveFunction("test_hidden_identity", fromTypes(hiddenType));
PageProjection projection = functionResolution.getPageFunctionCompiler()
.compileProjection(call(identity, call(constructor)), Optional.empty())
.get();

Page page = createLongBlockPage(0, 1);
Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount()));
assertThat(result.getPositionCount()).isEqualTo(page.getPositionCount());
assertThat(hiddenType.getObjectValue(result, 0)).isEqualTo(42);
assertThat(hiddenType.getObjectValue(result, 1)).isEqualTo(42);
}

@Test
public void testCache()
{
Expand Down Expand Up @@ -98,4 +173,127 @@ private static Page createLongBlockPage(long... values)
}
return new Page(builder.build());
}

private static HiddenFunctions createHiddenFunctions()
{
ClassDefinition classDefinition = new ClassDefinition(
a(PUBLIC, FINAL),
makeClassName("HiddenValue"),
type(Object.class));
FieldDefinition valueField = classDefinition.declareField(a(PUBLIC), "value", int.class);

Parameter constructorValue = arg("value", int.class);
MethodDefinition constructor = classDefinition.declareConstructor(a(PUBLIC), constructorValue);
constructor.getBody()
.append(constructor.getThis())
.invokeConstructor(Object.class)
.append(constructor.getThis().setField(valueField, constructorValue))
.ret();

Parameter identityValue = arg("value", classDefinition.getType());
MethodDefinition identity = classDefinition.declareMethod(a(PUBLIC, STATIC), "identity", classDefinition.getType(), identityValue);
identity.getBody()
.append(identityValue.ret());

Class<?> hiddenClass = defineClass(classDefinition, Object.class, new DynamicClassLoader(TestPageFunctionCompiler.class.getClassLoader()));
return new HiddenFunctions(
new HiddenType(hiddenClass),
insertArguments(constructorMethodHandle(hiddenClass, int.class), 0, 42),
methodHandle(hiddenClass, "identity", hiddenClass));
}

private record HiddenFunctions(Type type, MethodHandle constructor, MethodHandle identity) {}

private static final class HiddenFunction
extends SqlScalarFunction
{
private final MethodHandle methodHandle;
private final List<InvocationArgumentConvention> argumentConventions;

private HiddenFunction(String name, Type returnType, MethodHandle methodHandle, List<Type> argumentTypes)
{
super(functionMetadata(name, returnType, argumentTypes));
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.argumentConventions = nCopies(argumentTypes.size(), NEVER_NULL);
}

@Override
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
{
return new ChoicesSpecializedSqlScalarFunction(boundSignature, FAIL_ON_NULL, argumentConventions, methodHandle);
}

private static FunctionMetadata functionMetadata(String name, Type returnType, List<Type> argumentTypes)
{
Signature.Builder signature = Signature.builder()
.returnType(returnType);
argumentTypes.forEach(signature::argumentType);
return FunctionMetadata.scalarBuilder(name)
.signature(signature.build())
.hidden()
.description("test hidden type function")
.build();
}
}

private static final class HiddenType
extends AbstractVariableWidthType
{
private static final TypeSignature TYPE_SIGNATURE = new TypeSignature("test_hidden");

private final Field valueField;

private HiddenType(Class<?> javaType)
{
super(TYPE_SIGNATURE, javaType);
valueField = field(javaType, "value");
}

@Override
public String getDisplayName()
{
return TYPE_SIGNATURE.toString();
}

@Override
public Object getObjectValue(Block block, int position)
{
if (block.isNull(position)) {
return null;
}
return getSlice(block, position).getInt(0);
}

@Override
public Slice getSlice(Block block, int position)
{
VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock();
return valueBlock.getSlice(block.getUnderlyingValuePosition(position));
}

@Override
public void writeSlice(BlockBuilder blockBuilder, Slice value)
{
writeSlice(blockBuilder, value, 0, value.length());
}

@Override
public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length)
{
((VariableWidthBlockBuilder) blockBuilder).writeEntry(value, offset, length);
}

@Override
public void writeObject(BlockBuilder blockBuilder, Object value)
{
try {
Slice slice = allocate(Integer.BYTES);
slice.setInt(0, valueField.getInt(value));
writeSlice(blockBuilder, slice);
}
catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
}
Loading