diff --git a/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java index 0e93056d876c..afe1079fd869 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java @@ -20,7 +20,10 @@ import io.trino.sql.gen.PageProjectionWork; import io.trino.sql.relational.RowExpression; +import java.lang.invoke.MethodHandle; + import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Throwables.throwIfUnchecked; import static java.util.Objects.requireNonNull; public class GeneratedPageProjection @@ -29,16 +32,16 @@ public class GeneratedPageProjection private final RowExpression projection; private final boolean isDeterministic; private final InputChannels inputChannels; - private final PageProjectionWork pageProjectionWork; + private final MethodHandle pageProjectionWorkFactory; private BlockBuilder blockBuilder; - public GeneratedPageProjection(RowExpression projection, boolean isDeterministic, InputChannels inputChannels, PageProjectionWork pageProjectionWork) + public GeneratedPageProjection(RowExpression projection, boolean isDeterministic, InputChannels inputChannels, MethodHandle pageProjectionWorkFactory) { this.projection = requireNonNull(projection, "projection is null"); this.isDeterministic = isDeterministic; this.inputChannels = requireNonNull(inputChannels, "inputChannels is null"); - this.pageProjectionWork = requireNonNull(pageProjectionWork, "pageProjectionWork is null"); + this.pageProjectionWorkFactory = requireNonNull(pageProjectionWorkFactory, "pageProjectionWorkFactory is null"); this.blockBuilder = projection.type().createBlockBuilder(null, 1); } @@ -58,7 +61,12 @@ public InputChannels getInputChannels() public Block project(ConnectorSession session, SourcePage page, SelectedPositions selectedPositions) { blockBuilder = blockBuilder.newBlockBuilderLike(selectedPositions.size(), null); - return pageProjectionWork.process(session, page, selectedPositions, blockBuilder); + try { + return ((PageProjectionWork) pageProjectionWorkFactory.invoke(blockBuilder, session, page, selectedPositions)).process(); + } + catch (Throwable throwable) { + throw propagate(throwable); + } } @Override @@ -68,4 +76,13 @@ public String toString() .add("projection", projection) .toString(); } + + private static RuntimeException propagate(Throwable throwable) + { + if (throwable instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java index dbc5049a1ecf..7d3ca184cd25 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java @@ -61,7 +61,7 @@ import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import java.lang.reflect.Constructor; +import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Map; import java.util.Optional; @@ -72,7 +72,6 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Throwables.throwIfInstanceOf; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.bytecode.Access.FINAL; import static io.airlift.bytecode.Access.PRIVATE; import static io.airlift.bytecode.Access.PUBLIC; @@ -98,6 +97,7 @@ import static io.trino.sql.relational.DeterminismEvaluator.isDeterministic; import static io.trino.util.CompilerUtils.defineClass; import static io.trino.util.CompilerUtils.makeClassName; +import static io.trino.util.Reflection.constructorMethodHandle; import static java.util.Objects.requireNonNull; public class PageFunctionCompiler @@ -198,10 +198,9 @@ private Supplier compileProjectionInternal(RowExpression project ClassDefinition pageProjectionWorkDefinition = definePageProjectWorkClass(result.getRewrittenExpression(), callSiteBinder, classNameSuffix); - Constructor pageProjectionWorkConstructor; + Class pageProjectionWorkClass; try { - Class pageProjectionWorkClass = defineClass(pageProjectionWorkDefinition, PageProjectionWork.class, callSiteBinder.getBindings(), getClass().getClassLoader()); - pageProjectionWorkConstructor = pageProjectionWorkClass.getConstructor(); + pageProjectionWorkClass = defineClass(pageProjectionWorkDefinition, PageProjectionWork.class, callSiteBinder.getBindings(), getClass().getClassLoader()); } catch (Exception e) { if (Throwables.getRootCause(e) instanceof MethodTooLargeException) { @@ -211,19 +210,12 @@ private Supplier compileProjectionInternal(RowExpression project throw new TrinoException(COMPILER_ERROR, e); } - return () -> { - try { - PageProjectionWork pageProjectionWork = pageProjectionWorkConstructor.newInstance(); - return new GeneratedPageProjection( - result.getRewrittenExpression(), - isExpressionDeterministic, - result.getInputChannels(), - pageProjectionWork); - } - catch (ReflectiveOperationException e) { - throw new TrinoException(COMPILER_ERROR, e); - } - }; + MethodHandle pageProjectionConstructor = constructorMethodHandle(pageProjectionWorkClass, BlockBuilder.class, ConnectorSession.class, SourcePage.class, SelectedPositions.class); + return () -> new GeneratedPageProjection( + result.getRewrittenExpression(), + isExpressionDeterministic, + result.getInputChannels(), + pageProjectionConstructor); } private static ParameterizedType generateProjectionWorkClassName(Optional classNameSuffix) @@ -240,29 +232,40 @@ private ClassDefinition definePageProjectWorkClass(RowExpression projection, Cal type(PageProjectionWork.class)); FieldDefinition blockBuilderField = classDefinition.declareField(a(PRIVATE), "blockBuilder", BlockBuilder.class); + FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE), "session", ConnectorSession.class); + FieldDefinition selectedPositionsField = classDefinition.declareField(a(PRIVATE), "selectedPositions", SelectedPositions.class); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); - List inputChannels = getInputChannels(projection); - List blockFields = inputChannels.stream() - .map(channel -> classDefinition.declareField(a(PRIVATE), "block_" + channel, Block.class)) - .collect(toImmutableList()); // process - generateProcessMethod(classDefinition, blockBuilderField, blockFields, inputChannels); + generateProcessMethod(classDefinition, blockBuilderField, sessionField, selectedPositionsField); // evaluate Map compiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, projection); generateEvaluateMethod(classDefinition, callSiteBinder, cachedInstanceBinder, compiledLambdaMap, projection, blockBuilderField); // constructor - MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); + Parameter blockBuilder = arg("blockBuilder", BlockBuilder.class); + Parameter session = arg("session", ConnectorSession.class); + Parameter page = arg("page", SourcePage.class); + Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class); + + MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), blockBuilder, session, page, selectedPositions); BytecodeBlock body = constructorDefinition.getBody(); Variable thisVariable = constructorDefinition.getThis(); body.comment("super();") .append(thisVariable) - .invokeConstructor(Object.class); + .invokeConstructor(Object.class) + .append(thisVariable.setField(blockBuilderField, blockBuilder)) + .append(thisVariable.setField(sessionField, session)) + .append(thisVariable.setField(selectedPositionsField, selectedPositions)); + + for (int channel : getInputChannels(projection)) { + FieldDefinition blockField = classDefinition.declareField(a(PRIVATE, FINAL), "block_" + channel, Block.class); + body.append(thisVariable.setField(blockField, page.invoke("getBlock", Block.class, constantInt(channel)))); + } cachedInstanceBinder.generateInitializations(thisVariable, body); body.ret(); @@ -270,66 +273,48 @@ private ClassDefinition definePageProjectWorkClass(RowExpression projection, Cal return classDefinition; } - private static void generateProcessMethod( + private static MethodDefinition generateProcessMethod( ClassDefinition classDefinition, - FieldDefinition blockBuilderField, - List blockFields, - List inputChannels) + FieldDefinition blockBuilder, + FieldDefinition session, + FieldDefinition selectedPositions) { - Parameter session = arg("session", ConnectorSession.class); - Parameter page = arg("page", SourcePage.class); - Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class); - Parameter blockBuilder = arg("blockBuilder", BlockBuilder.class); - - MethodDefinition method = classDefinition.declareMethod( - a(PUBLIC), - "process", - type(Block.class), - ImmutableList.builder() - .add(session) - .add(page) - .add(selectedPositions) - .add(blockBuilder) - .build()); + MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "process", type(Block.class), ImmutableList.of()); Scope scope = method.getScope(); Variable thisVariable = method.getThis(); BytecodeBlock body = method.getBody(); - for (int i = 0; i < inputChannels.size(); i++) { - int channel = inputChannels.get(i); - body.append(thisVariable.setField(blockFields.get(i), page.invoke("getBlock", Block.class, constantInt(channel)))); - } - body.append(thisVariable.setField(blockBuilderField, blockBuilder)); - - Variable from = scope.declareVariable("from", body, selectedPositions.invoke("getOffset", int.class)); - Variable to = scope.declareVariable("to", body, add(selectedPositions.invoke("getOffset", int.class), selectedPositions.invoke("size", int.class))); + Variable from = scope.declareVariable("from", body, thisVariable.getField(selectedPositions).invoke("getOffset", int.class)); + Variable to = scope.declareVariable("to", body, add(thisVariable.getField(selectedPositions).invoke("getOffset", int.class), thisVariable.getField(selectedPositions).invoke("size", int.class))); Variable positions = scope.declareVariable(int[].class, "positions"); Variable index = scope.declareVariable(int.class, "index"); IfStatement ifStatement = new IfStatement() - .condition(selectedPositions.invoke("isList", boolean.class)); + .condition(thisVariable.getField(selectedPositions).invoke("isList", boolean.class)); body.append(ifStatement); ifStatement.ifTrue(new BytecodeBlock() - .append(positions.set(selectedPositions.invoke("getPositions", int[].class))) + .append(positions.set(thisVariable.getField(selectedPositions).invoke("getPositions", int[].class))) .append(new ForLoop("positions loop") .initialize(index.set(from)) .condition(lessThan(index, to)) .update(index.increment()) .body(new BytecodeBlock() - .append(thisVariable.invoke("evaluate", void.class, session, positions.getElement(index)))))); + .append(thisVariable.invoke("evaluate", void.class, thisVariable.getField(session), positions.getElement(index)))))); ifStatement.ifFalse(new ForLoop("range based loop") .initialize(index.set(from)) .condition(lessThan(index, to)) .update(index.increment()) .body(new BytecodeBlock() - .append(thisVariable.invoke("evaluate", void.class, session, index)))); + .append(thisVariable.invoke("evaluate", void.class, thisVariable.getField(session), index)))); body.comment("return this.blockBuilder.build();") - .append(thisVariable.getField(blockBuilderField).invoke("build", Block.class)) + .append(thisVariable.getField(blockBuilder).invoke("build", Block.class)) .retObject(); + + return method; } private MethodDefinition generateEvaluateMethod( diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/PageProjectionWork.java b/core/trino-main/src/main/java/io/trino/sql/gen/PageProjectionWork.java index 91e684b1f8f5..05b57affde3f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/PageProjectionWork.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/PageProjectionWork.java @@ -13,13 +13,9 @@ */ package io.trino.sql.gen; -import io.trino.operator.project.SelectedPositions; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.SourcePage; public interface PageProjectionWork { - Block process(ConnectorSession session, SourcePage page, SelectedPositions selectedPositions, BlockBuilder builder); + Block process(); }