diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java index ca53a645eb1e..cecf5c29ae34 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java @@ -151,8 +151,8 @@ private static MethodHandle generateTransform(Type inputType, Type outputType) private static MethodDefinition generateTransformValueInner(ClassDefinition definition, CallSiteBinder binder, Type inputType, Type outputType) { - Class inputJavaType = Primitives.wrap(inputType.getJavaType()); - Class outputJavaType = Primitives.wrap(outputType.getJavaType()); + Class inputJavaType = binder.getAccessibleType(Primitives.wrap(inputType.getJavaType())); + Class outputJavaType = binder.getAccessibleType(Primitives.wrap(outputType.getJavaType())); Parameter block = arg("block", Block.class); Parameter function = arg("function", UnaryFunctionInterface.class); @@ -190,7 +190,7 @@ private static MethodDefinition generateTransformValueInner(ClassDefinition defi writeOutputElement = new IfStatement() .condition(equal(outputElement, constantNull(outputJavaType))) .ifTrue(elementBuilder.invoke("appendNull", BlockBuilder.class).pop()) - .ifFalse(constantType(binder, outputType).writeValue(elementBuilder, outputElement.cast(outputType.getJavaType()))); + .ifFalse(constantType(binder, outputType).writeValue(elementBuilder, outputElement.cast(outputJavaType))); } else { writeOutputElement = new BytecodeBlock().append(elementBuilder.invoke("appendNull", BlockBuilder.class).pop()); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java index e73d02b6a0b0..cf00f827fc68 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java @@ -161,8 +161,8 @@ private static MethodDefinition generateFilterInner(ClassDefinition definition, Type keyType = mapType.getKeyType(); Type valueType = mapType.getValueType(); - Class keyJavaType = Primitives.wrap(keyType.getJavaType()); - Class valueJavaType = Primitives.wrap(valueType.getJavaType()); + Class keyJavaType = binder.getAccessibleType(Primitives.wrap(keyType.getJavaType())); + Class valueJavaType = binder.getAccessibleType(Primitives.wrap(valueType.getJavaType())); Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class)); Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class)); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java index 71f5eff6924e..eb481d09805c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java @@ -185,9 +185,9 @@ private static MethodDefinition generateTransformKeyInner(ClassDefinition defini BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); - Class keyJavaType = Primitives.wrap(keyType.getJavaType()); - Class transformedKeyJavaType = Primitives.wrap(transformedKeyType.getJavaType()); - Class valueJavaType = Primitives.wrap(valueType.getJavaType()); + Class keyJavaType = binder.getAccessibleType(Primitives.wrap(keyType.getJavaType())); + Class transformedKeyJavaType = binder.getAccessibleType(Primitives.wrap(transformedKeyType.getJavaType())); + Class valueJavaType = binder.getAccessibleType(Primitives.wrap(valueType.getJavaType())); Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class)); Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class)); @@ -243,7 +243,7 @@ private static MethodDefinition generateTransformKeyInner(ClassDefinition defini .condition(equal(transformedKeyElement, constantNull(transformedKeyJavaType))) .ifTrue(throwNullKeyException) .ifFalse(new BytecodeBlock() - .append(constantType(binder, transformedKeyType).writeValue(keyBuilder, transformedKeyElement.cast(transformedKeyType.getJavaType()))) + .append(constantType(binder, transformedKeyType).writeValue(keyBuilder, transformedKeyElement.cast(transformedKeyJavaType))) .append(valueBuilder.invoke( "append", void.class, diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java index 62adc37eb1cf..7e8c087e026c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java @@ -173,9 +173,9 @@ private static MethodDefinition generateTransformInner(ClassDefinition definitio BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); - Class keyJavaType = Primitives.wrap(keyType.getJavaType()); - Class valueJavaType = Primitives.wrap(valueType.getJavaType()); - Class transformedValueJavaType = Primitives.wrap(transformedValueType.getJavaType()); + Class keyJavaType = binder.getAccessibleType(Primitives.wrap(keyType.getJavaType())); + Class valueJavaType = binder.getAccessibleType(Primitives.wrap(valueType.getJavaType())); + Class transformedValueJavaType = binder.getAccessibleType(Primitives.wrap(transformedValueType.getJavaType())); Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class)); Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class)); @@ -227,7 +227,7 @@ private static MethodDefinition generateTransformInner(ClassDefinition definitio writeTransformedValueElement = new IfStatement() .condition(equal(transformedValueElement, constantNull(transformedValueJavaType))) .ifTrue(valueBuilder.invoke("appendNull", BlockBuilder.class).pop()) - .ifFalse(constantType(binder, transformedValueType).writeValue(valueBuilder, transformedValueElement.cast(transformedValueType.getJavaType()))); + .ifFalse(constantType(binder, transformedValueType).writeValue(valueBuilder, transformedValueElement.cast(transformedValueJavaType))); } else { writeTransformedValueElement = valueBuilder.invoke("appendNull", BlockBuilder.class).pop(); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/AndCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/AndCodeGenerator.java index f438cd7ad447..1c119a2fe968 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/AndCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/AndCodeGenerator.java @@ -58,7 +58,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generator) ifWasNull.ifTrue() .comment("clear the null flag, pop residual value off stack, and push was null flag on the stack (true)") - .pop(term.type().getJavaType()) // discard residual value + .pop(generator.getCallSiteBinder().getAccessibleType(term.type().getJavaType())) // discard residual value .pop(boolean.class) // discard the previous "we've seen a null flag" .push(true); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/ArrayConstructorCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/ArrayConstructorCodeGenerator.java index 0daac1febc13..a18dbd3edb05 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/ArrayConstructorCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/ArrayConstructorCodeGenerator.java @@ -55,7 +55,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext context) Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class); block.append(blockBuilder.set(constantType(binder, elementType).invoke("createBlockBuilder", BlockBuilder.class, constantNull(BlockBuilderStatus.class), constantInt(elements.size())))); - Variable element = scope.getOrCreateTempVariable(elementType.getJavaType()); + Variable element = scope.getOrCreateTempVariable(binder.getAccessibleType(elementType.getJavaType())); for (Expression item : elements) { block.append(context.wasNull().set(constantFalse())); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/ArrayMapBytecodeExpression.java b/core/trino-main/src/main/java/io/trino/sql/gen/ArrayMapBytecodeExpression.java index 263d737c959a..20b9dcc6bc7c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/ArrayMapBytecodeExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/ArrayMapBytecodeExpression.java @@ -72,8 +72,8 @@ public ArrayMapBytecodeExpression( mapperDescription = "null"; } else { - Variable element = scope.declareVariable(fromType.getJavaType(), "element_" + NEXT_VARIABLE_ID.getAndIncrement()); - Variable newElement = scope.declareVariable(toType.getJavaType(), "newElement_" + NEXT_VARIABLE_ID.getAndIncrement()); + Variable element = scope.declareVariable(binder.getAccessibleType(fromType.getJavaType()), "element_" + NEXT_VARIABLE_ID.getAndIncrement()); + Variable newElement = scope.declareVariable(binder.getAccessibleType(toType.getJavaType()), "newElement_" + NEXT_VARIABLE_ID.getAndIncrement()); SqlTypeBytecodeExpression elementTypeConstant = constantType(binder, fromType); SqlTypeBytecodeExpression newElementTypeConstant = constantType(binder, toType); mapElement = new BytecodeBlock() diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BetweenCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/BetweenCodeGenerator.java index 02cdb9600556..cbefe7b10235 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BetweenCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BetweenCodeGenerator.java @@ -53,7 +53,8 @@ public BetweenCodeGenerator(Between between, Metadata metadata) @Override public BytecodeNode generateExpression(BytecodeGeneratorContext context) { - Variable firstValue = context.getScope().getOrCreateTempVariable(value.type().getJavaType()); + Class valueJavaType = context.getCallSiteBinder().getAccessibleType(value.type().getJavaType()); + Variable firstValue = context.getScope().getOrCreateTempVariable(valueJavaType); Reference valueReference = createTempReference(firstValue, value.type()); Logical newExpression = new Logical( @@ -68,7 +69,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext context) BytecodeBlock block = new BytecodeBlock() .comment("check if value is null") .append(context.generate(value)) - .append(ifWasNullPopAndGoto(context.getScope(), done, boolean.class, value.type().getJavaType())) + .append(ifWasNullPopAndGoto(context.getScope(), done, boolean.class, valueJavaType)) .putVariable(firstValue) .append(context.generate(newExpression)) .visitLabel(done); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/CoalesceCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/CoalesceCodeGenerator.java index 9e26513e8b64..d1883246eb90 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/CoalesceCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/CoalesceCodeGenerator.java @@ -44,6 +44,8 @@ public CoalesceCodeGenerator(Coalesce coalesce) @Override public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext) { + Class returnJavaType = generatorContext.getCallSiteBinder().getAccessibleType(returnType.getJavaType()); + List operands = new ArrayList<>(); for (Expression expression : arguments) { operands.add(generatorContext.generate(expression)); @@ -52,7 +54,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext Variable wasNull = generatorContext.wasNull(); BytecodeNode nullValue = new BytecodeBlock() .append(wasNull.set(constantTrue())) - .pushJavaDefault(returnType.getJavaType()); + .pushJavaDefault(returnJavaType); // reverse list because current if statement builder doesn't support if/else so we need to build the if statements bottom up for (BytecodeNode operand : operands.reversed()) { @@ -64,7 +66,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext // if value was null, pop the null value, clear the null flag, and process the next operand ifStatement.ifTrue() - .pop(returnType.getJavaType()) + .pop(returnJavaType) .append(wasNull.set(constantFalse())) .append(nullValue); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java index 1c10e6e25e95..76c74212d8bc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/DereferenceCodeGenerator.java @@ -60,7 +60,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generator) IfStatement ifRowBlockIsNull = new IfStatement("if row block is null...") .condition(wasNull); - Class javaType = returnType.getJavaType(); + Class javaType = callSiteBinder.getAccessibleType(returnType.getJavaType()); LabelNode end = new LabelNode("end"); ifRowBlockIsNull.ifTrue() .comment("if row block is null, push null to the stack and goto 'end' label (return)") diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/InputReferenceCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/InputReferenceCompiler.java index 21ba4f020c84..0448c65cf5bd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/InputReferenceCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/InputReferenceCompiler.java @@ -68,8 +68,9 @@ private InputReferenceNode(CallSiteBinder callSiteBinder, Scope scope, Type type String methodName = "get" + Primitives.wrap(callType).getSimpleName(); BytecodeExpression value = constantType(callSiteBinder, type).invoke(methodName, callType, block, position); - if (callType != type.getJavaType()) { - value = value.cast(type.getJavaType()); + Class expectedType = callSiteBinder.getAccessibleType(type.getJavaType()); + if (callType != expectedType) { + value = value.cast(expectedType); } ifStatement.ifFalse(value); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/IsNullCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/IsNullCodeGenerator.java index 45d250e6255d..c525a4491bb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/IsNullCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/IsNullCodeGenerator.java @@ -49,7 +49,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext BytecodeBlock block = new BytecodeBlock() .comment("is null") .append(value) - .pop(argument.type().getJavaType()) + .pop(generatorContext.getCallSiteBinder().getAccessibleType(argument.type().getJavaType())) .append(wasNull); // clear the null flag diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java index 243b1591096f..4c5de8fddeee 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java @@ -127,7 +127,7 @@ public static CompiledLambda preGenerateLambdaExpression( parameters.add(arg("session", ConnectorSession.class)); for (int i = 0; i < lambdaExpression.arguments().size(); i++) { Symbol argument = lambdaExpression.arguments().get(i); - Class type = Primitives.wrap(argument.type().getJavaType()); + Class type = callSiteBinder.getAccessibleType(Primitives.wrap(argument.type().getJavaType())); String argumentName = argument.name(); Parameter arg = arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argumentName), type); parameters.add(arg); @@ -153,6 +153,7 @@ public static CompiledLambda preGenerateLambdaExpression( classDefinition, methodName, parameters.build(), + callSiteBinder, lambdaExpression); } @@ -161,10 +162,11 @@ private static CompiledLambda defineLambdaMethod( ClassDefinition classDefinition, String methodName, List inputParameters, + CallSiteBinder callSiteBinder, Lambda lambda) { checkCondition(inputParameters.size() <= 254, NOT_SUPPORTED, "Too many arguments for lambda expression"); - Class returnType = Primitives.wrap(lambda.body().type().getJavaType()); + Class returnType = callSiteBinder.getAccessibleType(Primitives.wrap(lambda.body().type().getJavaType())); MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), methodName, type(returnType), inputParameters); Scope scope = method.getScope(); @@ -209,7 +211,7 @@ public static BytecodeNode generateLambda( ImmutableList.Builder captureVariableBuilder = ImmutableList.builderWithExpectedSize(captureExpressions.size()); List captureTempVariables = new ArrayList<>(captureExpressions.size()); for (Expression captureExpression : captureExpressions) { - Class valueType = Primitives.wrap(captureExpression.type().getJavaType()); + Class valueType = context.getCallSiteBinder().getAccessibleType(Primitives.wrap(captureExpression.type().getJavaType())); Variable valueVariable = scope.getOrCreateTempVariable(valueType); captureTempVariables.add(valueVariable); block.append(context.generate(captureExpression)); @@ -296,7 +298,7 @@ public static Class> compileLambdaProvider(Lambda lam parameters.add(arg("session", ConnectorSession.class)); for (int i = 0; i < lambdaExpression.arguments().size(); i++) { Symbol argument = lambdaExpression.arguments().get(i); - Class type = Primitives.wrap(argument.type().getJavaType()); + Class type = callSiteBinder.getAccessibleType(Primitives.wrap(argument.type().getJavaType())); parameters.add(arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argument.name()), type)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/NullIfCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/NullIfCodeGenerator.java index c00baedbf231..998416cab861 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/NullIfCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/NullIfCodeGenerator.java @@ -66,16 +66,17 @@ private static Optional getCastIfNeeded(Metadata metadata, io. public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext) { Scope scope = generatorContext.getScope(); + Class firstJavaType = generatorContext.getCallSiteBinder().getAccessibleType(first.type().getJavaType()); LabelNode notMatch = new LabelNode("notMatch"); // push first arg on the stack - Variable firstValue = scope.getOrCreateTempVariable(first.type().getJavaType()); + Variable firstValue = scope.getOrCreateTempVariable(firstJavaType); BytecodeBlock block = new BytecodeBlock() .comment("check if first arg is null") .append(generatorContext.generate(first)) .append(ifWasNullPopAndGoto(scope, notMatch, void.class)) - .dup(first.type().getJavaType()) + .dup(firstJavaType) .putVariable(firstValue); BytecodeNode secondValue = generatorContext.generate(second); @@ -94,8 +95,8 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext // if first and second are equal, return null BytecodeBlock trueBlock = new BytecodeBlock() .append(generatorContext.wasNull().set(constantTrue())) - .pop(first.type().getJavaType()) - .pushJavaDefault(first.type().getJavaType()); + .pop(firstJavaType) + .pushJavaDefault(firstJavaType); // else return first (which is still on the stack block.append(new IfStatement() diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/OrCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/OrCodeGenerator.java index 4c3d740b160d..c7b933ffc81a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/OrCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/OrCodeGenerator.java @@ -58,7 +58,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generator) ifWasNull.ifTrue() .comment("clear the null flag, pop residual value off stack, and push was null flag on the stack (true)") - .pop(term.type().getJavaType()) // discard residual value + .pop(generator.getCallSiteBinder().getAccessibleType(term.type().getJavaType())) // discard residual value .pop(boolean.class) // discard the previous "we've seen a null flag" .push(true); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java index 8cd0013060ee..e70feffbdb88 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java @@ -94,7 +94,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext context) block.comment("Clean wasNull and Generate + " + i + "-th field of row"); block.append(context.wasNull().set(constantFalse())); block.append(context.generate(arguments.get(i))); - Variable field = scope.getOrCreateTempVariable(fieldType.getJavaType()); + Variable field = scope.getOrCreateTempVariable(binder.getAccessibleType(fieldType.getJavaType())); block.putVariable(field); block.append(new IfStatement() .condition(context.wasNull()) @@ -181,7 +181,7 @@ private MethodDefinition generatePartialRowConstructor(int start, int end, Bytec block.append(context.wasNull().set(constantFalse())); block.append(context.generate(arguments.get(i))); - Variable field = scope.getOrCreateTempVariable(fieldType.getJavaType()); + Variable field = scope.getOrCreateTempVariable(binder.getAccessibleType(fieldType.getJavaType())); block.putVariable(field); block.append(new IfStatement() .condition(context.wasNull()) diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/SqlTypeBytecodeExpression.java b/core/trino-main/src/main/java/io/trino/sql/gen/SqlTypeBytecodeExpression.java index 33a577f887e1..68031abfbd16 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/SqlTypeBytecodeExpression.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/SqlTypeBytecodeExpression.java @@ -37,7 +37,7 @@ public static SqlTypeBytecodeExpression constantType(CallSiteBinder callSiteBind requireNonNull(type, "type is null"); Binding binding = callSiteBinder.bind(type, Type.class); - return new SqlTypeBytecodeExpression(type, binding, BOOTSTRAP_METHOD); + return new SqlTypeBytecodeExpression(type, callSiteBinder.getAccessibleType(type.getJavaType()), binding, BOOTSTRAP_METHOD); } private static String generateName(Type type) @@ -51,13 +51,15 @@ private static String generateName(Type type) } private final Type type; + private final Class accessibleJavaElementType; private final Binding binding; private final Method bootstrapMethod; - private SqlTypeBytecodeExpression(Type type, Binding binding, Method bootstrapMethod) + private SqlTypeBytecodeExpression(Type type, Class accessibleJavaElementType, Binding binding, Method bootstrapMethod) { super(type(Type.class)); this.type = requireNonNull(type, "type is null"); + this.accessibleJavaElementType = requireNonNull(accessibleJavaElementType, "accessibleJavaElementType is null"); this.binding = requireNonNull(binding, "binding is null"); this.bootstrapMethod = requireNonNull(bootstrapMethod, "bootstrapMethod is null"); } @@ -96,7 +98,7 @@ public BytecodeExpression getValue(BytecodeExpression block, BytecodeExpression if (fromJavaElementType == Slice.class) { return invoke("getSlice", Slice.class, block, position); } - return invoke("getObject", Object.class, block, position).cast(fromJavaElementType); + return invoke("getObject", Object.class, block, position).cast(accessibleJavaElementType); } public BytecodeExpression writeValue(BytecodeExpression blockBuilder, BytecodeExpression value) @@ -104,13 +106,13 @@ public BytecodeExpression writeValue(BytecodeExpression blockBuilder, BytecodeEx Class fromJavaElementType = type.getJavaType(); if (fromJavaElementType == boolean.class) { - return invoke("writeBoolean", void.class, blockBuilder, value); + return invoke("writeBoolean", void.class, blockBuilder, value.cast(boolean.class)); } if (fromJavaElementType == long.class) { - return invoke("writeLong", void.class, blockBuilder, value); + return invoke("writeLong", void.class, blockBuilder, value.cast(long.class)); } if (fromJavaElementType == double.class) { - return invoke("writeDouble", void.class, blockBuilder, value); + return invoke("writeDouble", void.class, blockBuilder, value.cast(double.class)); } if (fromJavaElementType == Slice.class) { return invoke("writeSlice", void.class, blockBuilder, value.cast(Slice.class)); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/SwitchCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/SwitchCodeGenerator.java index 185e78fabba0..b4680c50227d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/SwitchCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/SwitchCodeGenerator.java @@ -88,6 +88,7 @@ else if ( == ) { */ Scope scope = generatorContext.getScope(); + CallSiteBinder callSiteBinder = generatorContext.getCallSiteBinder(); // process value, else, and all when clauses BytecodeNode valueBytecode = generatorContext.generate(value); @@ -95,7 +96,7 @@ else if ( == ) { BytecodeNode elseValue = generatorContext.generate(defaultValue); // determine the type of the value and result - Class valueType = value.type().getJavaType(); + Class valueType = callSiteBinder.getAccessibleType(value.type().getJavaType()); // evaluate the value and store it in a variable LabelNode nullValue = new LabelNode("nullCondition"); diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java index e941ca7c804e..b1527392bd83 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java @@ -31,8 +31,10 @@ import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.SourcePage; @@ -41,19 +43,28 @@ import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.Signature; import io.trino.spi.type.AbstractVariableWidthType; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import io.trino.sql.PlannerContext; import io.trino.sql.ir.Call; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.FieldReference; +import io.trino.sql.ir.Lambda; import io.trino.sql.ir.Reference; +import io.trino.sql.ir.Row; import io.trino.sql.planner.Symbol; import io.trino.transaction.TransactionManager; +import io.trino.type.FunctionType; import org.junit.jupiter.api.Test; import java.lang.invoke.MethodHandle; +import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.util.List; import java.util.Map; @@ -68,11 +79,13 @@ import static io.airlift.bytecode.ParameterizedType.type; import static io.airlift.slice.Slices.allocate; import static io.trino.block.BlockAssertions.createRepeatedValuesBlock; +import static io.trino.operator.scalar.ArrayTransformFunction.ARRAY_TRANSFORM_NAME; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.ADD; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.ir.IrExpressions.call; @@ -92,6 +105,7 @@ public class TestPageFunctionCompiler { + private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution(); private static final Map LAYOUT = ImmutableMap.of(new Symbol(BIGINT, "$col_0"), 2); private static final Call ADD_10_EXPRESSION = call( @@ -152,6 +166,128 @@ public void testProjectionWithPrivateJavaType() assertThat(hiddenType.getObjectValue(result, 1)).isEqualTo(42); } + @Test + public void testTransformWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + ArrayType inputArrayType = new ArrayType(BIGINT); + ArrayType outputArrayType = new ArrayType(hiddenType); + TestingFunctionResolution functionResolution = createFunctionResolution( + hiddenType, + new HiddenFunction("test_hidden_constructor", hiddenType, hiddenFunctions.constructor(), ImmutableList.of())); + + ResolvedFunction constructor = functionResolution.resolveFunction("test_hidden_constructor", fromTypes()); + ResolvedFunction transform = functionResolution.resolveFunction(ARRAY_TRANSFORM_NAME, fromTypes(inputArrayType, new FunctionType(ImmutableList.of(BIGINT), hiddenType))); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection( + call(transform, + new Reference(inputArrayType, "$col_0"), + new Lambda(ImmutableList.of(new Symbol(BIGINT, "x")), call(constructor))), + ImmutableMap.of(new Symbol(inputArrayType, "$col_0"), 0), + Optional.empty()) + .get(); + + Page page = createSingleArrayPage(inputArrayType, 1); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(outputArrayType.getObjectValue(result, 0)).isEqualTo(ImmutableList.of(42)); + } + + @Test + public void testTransformValuesWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + MapType inputMapType = new MapType(BIGINT, BIGINT, TYPE_OPERATORS); + MapType outputMapType = new MapType(BIGINT, hiddenType, TYPE_OPERATORS); + TestingFunctionResolution functionResolution = createFunctionResolution( + hiddenType, + new HiddenFunction("test_hidden_constructor", hiddenType, hiddenFunctions.constructor(), ImmutableList.of())); + + ResolvedFunction constructor = functionResolution.resolveFunction("test_hidden_constructor", fromTypes()); + ResolvedFunction transformValues = functionResolution.resolveFunction("transform_values", fromTypes(inputMapType, new FunctionType(ImmutableList.of(BIGINT, BIGINT), hiddenType))); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection( + call(transformValues, + new Reference(inputMapType, "$col_0"), + new Lambda(ImmutableList.of(new Symbol(BIGINT, "k"), new Symbol(BIGINT, "v")), call(constructor))), + ImmutableMap.of(new Symbol(inputMapType, "$col_0"), 0), + Optional.empty()) + .get(); + + Page page = createSingleLongMapPage(inputMapType, 1, 11); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(outputMapType.getObjectValue(result, 0)).isEqualTo(ImmutableMap.of(1L, 42)); + } + + @Test + public void testTransformKeysWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + MapType mapType = new MapType(BIGINT, hiddenType, TYPE_OPERATORS); + TestingFunctionResolution functionResolution = createFunctionResolution(hiddenType); + + ResolvedFunction transformKeys = functionResolution.resolveFunction("transform_keys", fromTypes(mapType, new FunctionType(ImmutableList.of(BIGINT, hiddenType), BIGINT))); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection( + call(transformKeys, + new Reference(mapType, "$col_0"), + new Lambda(ImmutableList.of(new Symbol(BIGINT, "k"), new Symbol(hiddenType, "v")), new Reference(BIGINT, "k"))), + ImmutableMap.of(new Symbol(mapType, "$col_0"), 0), + Optional.empty()) + .get(); + + Page page = createSingleHiddenValueMapPage(mapType, hiddenType, 1, createHiddenValue(hiddenFunctions)); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(mapType.getObjectValue(result, 0)).isEqualTo(ImmutableMap.of(1L, 42)); + } + + @Test + public void testMapFilterWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + MapType mapType = new MapType(BIGINT, hiddenType, TYPE_OPERATORS); + TestingFunctionResolution functionResolution = createFunctionResolution(hiddenType); + + ResolvedFunction mapFilter = functionResolution.resolveFunction("map_filter", fromTypes(mapType, new FunctionType(ImmutableList.of(BIGINT, hiddenType), BOOLEAN))); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection( + call(mapFilter, + new Reference(mapType, "$col_0"), + new Lambda(ImmutableList.of(new Symbol(BIGINT, "k"), new Symbol(hiddenType, "v")), new Constant(BOOLEAN, true))), + ImmutableMap.of(new Symbol(mapType, "$col_0"), 0), + Optional.empty()) + .get(); + + Page page = createSingleHiddenValueMapPage(mapType, hiddenType, 1, createHiddenValue(hiddenFunctions)); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(mapType.getObjectValue(result, 0)).isEqualTo(ImmutableMap.of(1L, 42)); + } + + @Test + public void testRowConstructorAndDereferenceWithPrivateJavaType() + { + HiddenFunctions hiddenFunctions = createHiddenFunctions(); + Type hiddenType = hiddenFunctions.type(); + RowType rowType = RowType.anonymous(ImmutableList.of(hiddenType)); + TestingFunctionResolution functionResolution = createFunctionResolution( + hiddenType, + new HiddenFunction("test_hidden_constructor", hiddenType, hiddenFunctions.constructor(), ImmutableList.of())); + + ResolvedFunction constructor = functionResolution.resolveFunction("test_hidden_constructor", fromTypes()); + Expression row = new Row(ImmutableList.of(call(constructor)), rowType); + Expression dereference = new FieldReference(row, 0); + PageProjection projection = functionResolution.getPageFunctionCompiler() + .compileProjection(dereference, ImmutableMap.of(), Optional.empty()) + .get(); + + Page page = createLongBlockPage(0); + Block result = project(projection, page, SelectedPositions.positionsRange(0, page.getPositionCount())); + assertThat(hiddenType.getObjectValue(result, 0)).isEqualTo(42); + } + @Test public void testProjectionCache() { @@ -305,6 +441,62 @@ private static Page createPageWithDataAtChannel2(long... values) return new Page(createRepeatedValuesBlock(0L, values.length), createRepeatedValuesBlock(0L, values.length), data.getBlock(0)); } + private static Page createSingleArrayPage(ArrayType arrayType, long... values) + { + ArrayBlockBuilder builder = arrayType.createBlockBuilder(null, 1); + builder.buildEntry(elementBuilder -> { + for (long value : values) { + BIGINT.writeLong(elementBuilder, value); + } + }); + return new Page(builder.build()); + } + + private static Page createSingleLongMapPage(MapType mapType, long key, long value) + { + MapBlockBuilder builder = mapType.createBlockBuilder(null, 1); + builder.buildEntry((keyBuilder, valueBuilder) -> { + BIGINT.writeLong(keyBuilder, key); + BIGINT.writeLong(valueBuilder, value); + }); + return new Page(builder.build()); + } + + private static Page createSingleHiddenValueMapPage(MapType mapType, Type hiddenType, long key, Object value) + { + MapBlockBuilder builder = mapType.createBlockBuilder(null, 1); + builder.buildEntry((keyBuilder, valueBuilder) -> { + BIGINT.writeLong(keyBuilder, key); + hiddenType.writeObject(valueBuilder, value); + }); + return new Page(builder.build()); + } + + private static TestingFunctionResolution createFunctionResolution(Type hiddenType, SqlScalarFunction... functions) + { + TransactionManager transactionManager = createTestTransactionManager(); + InternalFunctionBundle.InternalFunctionBundleBuilder functionBundle = InternalFunctionBundle.builder(); + for (SqlScalarFunction function : functions) { + functionBundle.function(function); + } + PlannerContext plannerContext = plannerContextBuilder() + .withTransactionManager(transactionManager) + .addType(hiddenType) + .addFunctions(functionBundle.build()) + .build(); + return new TestingFunctionResolution(transactionManager, plannerContext); + } + + private static Object createHiddenValue(HiddenFunctions hiddenFunctions) + { + try { + return hiddenFunctions.constructor().invoke(); + } + catch (Throwable e) { + throw new RuntimeException(e); + } + } + private static HiddenFunctions createHiddenFunctions() { ClassDefinition classDefinition = new ClassDefinition( @@ -372,11 +564,18 @@ private static final class HiddenType { private static final TypeSignature TYPE_SIGNATURE = new TypeSignature("test_hidden"); + private final Constructor constructor; private final Field valueField; private HiddenType(Class javaType) { super(TYPE_SIGNATURE, javaType); + try { + constructor = javaType.getConstructor(int.class); + } + catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } valueField = field(javaType, "value"); } @@ -395,6 +594,17 @@ public Object getObjectValue(Block block, int position) return getSlice(block, position).getInt(0); } + @Override + public Object getObject(Block block, int position) + { + try { + return constructor.newInstance(getSlice(block, position).getInt(0)); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + @Override public Slice getSlice(Block block, int position) {