diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index 2eae1525e16e..92ca087ab96d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -23,6 +23,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.AggregationImplementation; @@ -240,6 +241,11 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, S verifyFunctionSignature(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class), "Expected %s argument types to be Block and int".formatted(argumentConvention)); break; + case VALUE_BLOCK_POSITION: + case VALUE_BLOCK_POSITION_NOT_NULL: + verifyFunctionSignature(ValueBlock.class.isAssignableFrom(parameterType) && methodType.parameterType(parameterIndex + 1).equals(int.class), + "Expected %s argument types to be ValueBlock and int".formatted(argumentConvention)); + break; case FLAT: verifyFunctionSignature(parameterType.equals(byte[].class) && methodType.parameterType(parameterIndex + 1).equals(int.class) && diff --git a/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java index bd5ba587f35b..37676c4f9ebc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java @@ -35,7 +35,7 @@ public BucketPartitionFunction(BucketFunction bucketFunction, int[] bucketToPart } @Override - public int getPartitionCount() + public int partitionCount() { return partitionCount; } diff --git a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java index ee290ea6981c..8db25801a8b0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -68,7 +68,11 @@ public Page transformPage(Page inputPage) checkArgument(positionCount > 0, "positionCount should be > 0, but is %s", positionCount); ColumnarRow mergeRow = toColumnarRow(inputPage.getBlock(mergeRowChannel)); - checkArgument(!mergeRow.mayHaveNull(), "The mergeRow may not have null rows"); + if (mergeRow.mayHaveNull()) { + for (int position = 0; position < positionCount; position++) { + checkArgument(!mergeRow.isNull(position), "The mergeRow may not have null rows"); + } + } // We've verified that the mergeRow block has no null rows, so it's okay to get the field blocks diff --git a/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java b/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java index c847fbe6facd..77a3f5789ab0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java @@ -18,6 +18,9 @@ import io.airlift.units.DataSize; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.Type; @@ -32,8 +35,8 @@ import static io.trino.operator.VariableWidthData.EMPTY_CHUNK; import static io.trino.operator.VariableWidthData.POINTER_SIZE; import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; @@ -88,8 +91,8 @@ public class JoinDomainBuilder private int distinctSize; private int distinctMaxFill; - private Block minValue; - private Block maxValue; + private ValueBlock minValue; + private ValueBlock maxValue; private boolean collectDistinctValues = true; private boolean collectMinMax; @@ -116,15 +119,15 @@ public JoinDomainBuilder( MethodHandle readOperator = typeOperators.getReadValueOperator(type, simpleConvention(NULLABLE_RETURN, FLAT)); readOperator = readOperator.asType(readOperator.type().changeReturnType(Object.class)); this.readFlat = readOperator; - this.writeFlat = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + this.writeFlat = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); this.hashFlat = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); - this.hashBlock = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + this.hashBlock = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); this.distinctFlatFlat = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - this.distinctFlatBlock = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + this.distinctFlatBlock = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); if (collectMinMax) { this.compareFlatFlat = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - this.compareBlockBlock = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + this.compareBlockBlock = typeOperators.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); } else { this.compareFlatFlat = null; @@ -157,9 +160,24 @@ public boolean isCollecting() public void add(Block block) { + block = block.getLoadedBlock(); if (collectDistinctValues) { - for (int position = 0; position < block.getPositionCount(); ++position) { - add(block, position); + if (block instanceof ValueBlock valueBlock) { + for (int position = 0; position < block.getPositionCount(); position++) { + add(valueBlock, position); + } + } + else if (block instanceof RunLengthEncodedBlock rleBlock) { + add(rleBlock.getValue(), 0); + } + else if (block instanceof DictionaryBlock dictionaryBlock) { + ValueBlock dictionary = dictionaryBlock.getDictionary(); + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + add(dictionary, dictionaryBlock.getId(i)); + } + } + else { + throw new IllegalArgumentException("Unsupported block type: " + block.getClass().getSimpleName()); } // if the distinct size is too large, fall back to min max, and drop the distinct values @@ -207,8 +225,10 @@ else if (collectMinMax) { int minValuePosition = -1; int maxValuePosition = -1; - for (int position = 0; position < block.getPositionCount(); ++position) { - if (block.isNull(position)) { + ValueBlock valueBlock = block.getUnderlyingValueBlock(); + for (int i = 0; i < block.getPositionCount(); i++) { + int position = block.getUnderlyingValuePosition(i); + if (valueBlock.isNull(position)) { continue; } if (minValuePosition == -1) { @@ -217,10 +237,10 @@ else if (collectMinMax) { maxValuePosition = position; continue; } - if (valueCompare(block, position, block, minValuePosition) < 0) { + if (valueCompare(valueBlock, position, valueBlock, minValuePosition) < 0) { minValuePosition = position; } - else if (valueCompare(block, position, block, maxValuePosition) > 0) { + else if (valueCompare(valueBlock, position, valueBlock, maxValuePosition) > 0) { maxValuePosition = position; } } @@ -231,18 +251,18 @@ else if (valueCompare(block, position, block, maxValuePosition) > 0) { } if (minValue == null) { - minValue = block.getSingleValueBlock(minValuePosition); - maxValue = block.getSingleValueBlock(maxValuePosition); + minValue = valueBlock.getSingleValueBlock(minValuePosition); + maxValue = valueBlock.getSingleValueBlock(maxValuePosition); return; } - if (valueCompare(block, minValuePosition, minValue, 0) < 0) { + if (valueCompare(valueBlock, minValuePosition, minValue, 0) < 0) { retainedSizeInBytes -= minValue.getRetainedSizeInBytes(); - minValue = block.getSingleValueBlock(minValuePosition); + minValue = valueBlock.getSingleValueBlock(minValuePosition); retainedSizeInBytes += minValue.getRetainedSizeInBytes(); } - if (valueCompare(block, maxValuePosition, maxValue, 0) > 0) { + if (valueCompare(valueBlock, maxValuePosition, maxValue, 0) > 0) { retainedSizeInBytes -= maxValue.getRetainedSizeInBytes(); - maxValue = block.getSingleValueBlock(maxValuePosition); + maxValue = valueBlock.getSingleValueBlock(maxValuePosition); retainedSizeInBytes += maxValue.getRetainedSizeInBytes(); } } @@ -289,7 +309,7 @@ public Domain build() return Domain.all(type); } - private void add(Block block, int position) + private void add(ValueBlock block, int position) { // Inner and right join doesn't match rows with null key column values. if (block.isNull(position)) { @@ -343,7 +363,7 @@ private int matchInVector(byte[] otherValues, VariableWidthData otherVariableWid return -1; } - private int matchInVector(Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -367,7 +387,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private void insert(int index, Block block, int position, byte hashPrefix) + private void insert(int index, ValueBlock block, int position, byte hashPrefix) { setControl(index, hashPrefix); @@ -512,7 +532,7 @@ private Object readValueToObject(int position) } } - private Block readValueToBlock(int position) + private ValueBlock readValueToBlock(int position) { return writeNativeValue(type, readValueToObject(position)); } @@ -538,7 +558,7 @@ private long valueHashCode(byte[] values, int position) } } - private long valueHashCode(Block right, int rightPosition) + private long valueHashCode(ValueBlock right, int rightPosition) { try { return (long) hashBlock.invokeExact(right, rightPosition); @@ -549,7 +569,7 @@ private long valueHashCode(Block right, int rightPosition) } } - private boolean valueNotDistinctFrom(int leftPosition, Block right, int rightPosition) + private boolean valueNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = distinctRecords; int leftRecordOffset = getRecordOffset(leftPosition); @@ -603,7 +623,7 @@ private boolean valueNotDistinctFrom(int leftPosition, byte[] rightValues, Varia } } - private int valueCompare(Block left, int leftPosition, Block right, int rightPosition) + private int valueCompare(ValueBlock left, int leftPosition, ValueBlock right, int rightPosition) { try { return (int) (long) compareBlockBlock.invokeExact( diff --git a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java index 6f9768affeb8..ef1627843f86 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java @@ -17,7 +17,7 @@ public interface PartitionFunction { - int getPartitionCount(); + int partitionCount(); /** * @param page the arguments to bucketing function in order (no extra columns) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java index 02c3091a6cb5..e410366b95db 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AbstractMapAggregationState.java @@ -17,9 +17,9 @@ import com.google.common.primitives.Ints; import io.trino.operator.VariableWidthData; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -285,7 +285,7 @@ private void serializeEntry(BlockBuilder keyBuilder, BlockBuilder valueBuilder, } } - protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + protected void add(int groupId, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { checkArgument(!keyBlock.isNull(keyPosition), "key must not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -322,7 +322,7 @@ protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBloc } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -346,7 +346,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private void insert(int index, int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition, byte hashPrefix) + private void insert(int index, int groupId, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, byte hashPrefix) { setControl(index, hashPrefix); @@ -499,7 +499,7 @@ private long keyHashCode(int groupId, byte[] records, int index) } } - private long keyHashCode(int groupId, Block right, int rightPosition) + private long keyHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) keyHashBlock.invokeExact(right, rightPosition); @@ -511,7 +511,7 @@ private long keyHashCode(int groupId, Block right, int rightPosition) } } - private boolean keyNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean keyNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 2333a73b36a4..98b417aea187 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -33,6 +33,7 @@ import io.trino.spi.block.ColumnarRow; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.RowValueBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; @@ -71,6 +72,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder; import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD; import static io.trino.sql.gen.BytecodeUtils.invoke; @@ -88,7 +90,8 @@ private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( BoundSignature boundSignature, AggregationImplementation implementation, - FunctionNullability functionNullability) + FunctionNullability functionNullability, + boolean specializedLoops) { // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation implementation = normalizeAggregationMethods(implementation); @@ -98,19 +101,21 @@ public static AccumulatorFactory generateAccumulatorFactory( List argumentNullable = functionNullability.getArgumentNullable() .subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size()); - Constructor accumulatorConstructor = generateAccumulatorClass( + Constructor groupedAccumulatorConstructor = generateAccumulatorClass( boundSignature, - Accumulator.class, + GroupedAccumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); - Constructor groupedAccumulatorConstructor = generateAccumulatorClass( + Constructor accumulatorConstructor = generateAccumulatorClass( boundSignature, - GroupedAccumulator.class, + Accumulator.class, implementation, argumentNullable, - classLoader); + classLoader, + specializedLoops); List nonNullArguments = new ArrayList<>(); for (int argumentIndex = 0; argumentIndex < argumentNullable.size(); argumentIndex++) { @@ -132,7 +137,8 @@ private static Constructor generateAccumulatorClass( Class accumulatorInterface, AggregationImplementation implementation, List argumentNullable, - DynamicClassLoader classLoader) + DynamicClassLoader classLoader, + boolean specializedLoops) { boolean grouped = accumulatorInterface == GroupedAccumulator.class; @@ -180,6 +186,7 @@ private static Constructor generateAccumulatorClass( generateAddInput( definition, + specializedLoops, stateFields, argumentNullable, lambdaProviderFields, @@ -363,6 +370,7 @@ private static void generateSetGroupCount(ClassDefinition definition, List stateField, List argumentNullable, List lambdaProviderFields, @@ -395,6 +403,7 @@ private static void generateAddInput( } BytecodeBlock block = generateInputForLoop( + specializedLoops, stateField, inputFunction, scope, @@ -429,25 +438,40 @@ private static void generateAddOrRemoveInputWindowIndex( type(void.class), ImmutableList.of(index, startPosition, endPosition)); Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); Variable position = scope.declareVariable(int.class, "position"); + // input parameters + Variable inputBlockPosition = scope.declareVariable(int.class, "inputBlockPosition"); + List inputBlockVariables = new ArrayList<>(); + for (int i = 0; i < argumentNullable.size(); i++) { + inputBlockVariables.add(scope.declareVariable(Block.class, "inputBlock" + i)); + } + Binding binding = callSiteBinder.bind(inputFunction); - BytecodeExpression invokeInputFunction = invokeDynamic( + BytecodeBlock invokeInputFunction = new BytecodeBlock(); + // WindowIndex is built on PagesIndex, which simply wraps Blocks + // and currently does not understand ValueBlocks. + // Until PagesIndex is updated to understand ValueBlocks, the + // input function parameters must be directly unwrapped to ValueBlocks. + invokeInputFunction.append(inputBlockPosition.set(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position))); + for (int i = 0; i < inputBlockVariables.size(); i++) { + invokeInputFunction.append(inputBlockVariables.get(i).set(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position))); + } + invokeInputFunction.append(invokeDynamic( BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), generatedFunctionName, binding.getType(), getInvokeFunctionOnWindowIndexParameters( - scope, - argumentNullable.size(), - lambdaProviderFields, + scope.getThis(), stateField, - index, - position)); + inputBlockPosition, + inputBlockVariables, + lambdaProviderFields))); - method.getBody() - .append(new ForLoop() + body.append(new ForLoop() .initialize(position.set(startPosition)) .condition(BytecodeExpressions.lessThanOrEqual(position, endPosition)) .update(position.increment()) @@ -473,33 +497,28 @@ private static BytecodeExpression anyParametersAreNull( } private static List getInvokeFunctionOnWindowIndexParameters( - Scope scope, - int inputParameterCount, - List lambdaProviderFields, + Variable thisVariable, List stateField, - Variable index, - Variable position) + Variable inputBlockPosition, + List inputBlockVariables, + List lambdaProviderFields) { List expressions = new ArrayList<>(); // state parameters for (FieldDefinition field : stateField) { - expressions.add(scope.getThis().getField(field)); + expressions.add(thisVariable.getField(field)); } // input parameters - for (int i = 0; i < inputParameterCount; i++) { - expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position)); - } - - // position parameter - if (inputParameterCount > 0) { - expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position)); + for (Variable blockVariable : inputBlockVariables) { + expressions.add(blockVariable.invoke("getUnderlyingValueBlock", ValueBlock.class)); + expressions.add(blockVariable.invoke("getUnderlyingValuePosition", int.class, inputBlockPosition)); } // lambda parameters for (FieldDefinition lambdaProviderField : lambdaProviderFields) { - expressions.add(scope.getThis().getField(lambdaProviderField) + expressions.add(thisVariable.getField(lambdaProviderField) .invoke("get", Object.class)); } @@ -507,6 +526,7 @@ private static List getInvokeFunctionOnWindowIndexParameters } private static BytecodeBlock generateInputForLoop( + boolean specializedLoops, List stateField, MethodHandle inputFunction, Scope scope, @@ -516,6 +536,30 @@ private static BytecodeBlock generateInputForLoop( CallSiteBinder callSiteBinder, boolean grouped) { + if (specializedLoops) { + BytecodeBlock newBlock = new BytecodeBlock(); + Variable thisVariable = scope.getThis(); + + MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size(), grouped); + + ImmutableList.Builder parameters = ImmutableList.builder(); + parameters.add(mask); + if (grouped) { + parameters.add(scope.getVariable("groupIds")); + } + for (FieldDefinition fieldDefinition : stateField) { + parameters.add(thisVariable.getField(fieldDefinition)); + } + parameters.addAll(parameterVariables); + for (FieldDefinition lambdaProviderField : lambdaProviderFields) { + parameters.add(scope.getThis().getField(lambdaProviderField) + .invoke("get", Object.class)); + } + + newBlock.append(invoke(callSiteBinder.bind(mainLoop), "mainLoop", parameters.build())); + return newBlock; + } + // For-loop over rows Variable positionVariable = scope.declareVariable(int.class, "position"); Variable rowsVariable = scope.declareVariable(int.class, "rows"); @@ -596,11 +640,9 @@ private static BytecodeBlock generateInvokeInputFunction( } // input parameters - parameters.addAll(parameterVariables); - - // position parameter - if (!parameterVariables.isEmpty()) { - parameters.add(position); + for (Variable variable : parameterVariables) { + parameters.add(variable.invoke("getUnderlyingValueBlock", ValueBlock.class)); + parameters.add(variable.invoke("getUnderlyingValuePosition", int.class, position)); } // lambda parameters @@ -1054,32 +1096,38 @@ private static BytecodeExpression generateRequireNotNull(BytecodeExpression expr private static AggregationImplementation normalizeAggregationMethods(AggregationImplementation implementation) { // change aggregations state variables to simply AccumulatorState to avoid any class loader issues in generated code - int stateParameterCount = implementation.getAccumulatorStateDescriptors().size(); int lambdaParameterCount = implementation.getLambdaInterfaces().size(); AggregationImplementation.Builder builder = AggregationImplementation.builder(); - builder.inputFunction(castStateParameters(implementation.getInputFunction(), stateParameterCount, lambdaParameterCount)); + builder.inputFunction(normalizeParameters(implementation.getInputFunction(), lambdaParameterCount)); implementation.getRemoveInputFunction() - .map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)) + .map(removeFunction -> normalizeParameters(removeFunction, lambdaParameterCount)) .ifPresent(builder::removeInputFunction); implementation.getCombineFunction() - .map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)) + .map(combineFunction -> normalizeParameters(combineFunction, lambdaParameterCount)) .ifPresent(builder::combineFunction); - builder.outputFunction(castStateParameters(implementation.getOutputFunction(), stateParameterCount, 0)); + builder.outputFunction(normalizeParameters(implementation.getOutputFunction(), 0)); builder.accumulatorStateDescriptors(implementation.getAccumulatorStateDescriptors()); builder.lambdaInterfaces(implementation.getLambdaInterfaces()); return builder.build(); } - private static MethodHandle castStateParameters(MethodHandle inputFunction, int stateParameterCount, int lambdaParameterCount) + private static MethodHandle normalizeParameters(MethodHandle function, int lambdaParameterCount) { - Class[] parameterTypes = inputFunction.type().parameterArray(); - for (int i = 0; i < stateParameterCount; i++) { - parameterTypes[i] = AccumulatorState.class; + Class[] parameterTypes = function.type().parameterArray(); + for (int i = 0; i < parameterTypes.length; i++) { + Class parameterType = parameterTypes[i]; + if (AccumulatorState.class.isAssignableFrom(parameterType)) { + parameterTypes[i] = AccumulatorState.class; + } + else if (ValueBlock.class.isAssignableFrom(parameterType)) { + parameterTypes[i] = ValueBlock.class; + } } for (int i = parameterTypes.length - lambdaParameterCount; i < parameterTypes.length; i++) { parameterTypes[i] = Object.class; } - return MethodHandles.explicitCastArguments(inputFunction, MethodType.methodType(inputFunction.type().returnType(), parameterTypes)); + MethodType newType = MethodType.methodType(function.type().returnType(), parameterTypes); + return MethodHandles.explicitCastArguments(function, newType); } private static class StateFieldAndDescriptor diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index f27391dbefd5..a46cd8bcaa73 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -280,7 +280,7 @@ private static List getInputFunctions(Class clazz, List 1) { List> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction) .subList(0, stateDetails.size()); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java index 84d20bfddf86..6315b354cdd7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java @@ -15,15 +15,13 @@ import com.google.common.collect.ImmutableList; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BoundSignature; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -34,7 +32,7 @@ import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static java.lang.invoke.MethodHandles.collectArguments; import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodHandles.permuteArguments; +import static java.lang.invoke.MethodType.methodType; import static java.util.Objects.requireNonNull; public final class AggregationFunctionAdapter @@ -55,10 +53,14 @@ public enum AggregationParameterKind static { try { - BOOLEAN_TYPE_GETTER = lookup().findVirtual(Type.class, "getBoolean", MethodType.methodType(boolean.class, Block.class, int.class)); - LONG_TYPE_GETTER = lookup().findVirtual(Type.class, "getLong", MethodType.methodType(long.class, Block.class, int.class)); - DOUBLE_TYPE_GETTER = lookup().findVirtual(Type.class, "getDouble", MethodType.methodType(double.class, Block.class, int.class)); - OBJECT_TYPE_GETTER = lookup().findVirtual(Type.class, "getObject", MethodType.methodType(Object.class, Block.class, int.class)); + BOOLEAN_TYPE_GETTER = lookup().findVirtual(Type.class, "getBoolean", methodType(boolean.class, Block.class, int.class)) + .asType(methodType(boolean.class, Type.class, ValueBlock.class, int.class)); + LONG_TYPE_GETTER = lookup().findVirtual(Type.class, "getLong", methodType(long.class, Block.class, int.class)) + .asType(methodType(long.class, Type.class, ValueBlock.class, int.class)); + DOUBLE_TYPE_GETTER = lookup().findVirtual(Type.class, "getDouble", methodType(double.class, Block.class, int.class)) + .asType(methodType(double.class, Type.class, ValueBlock.class, int.class)); + OBJECT_TYPE_GETTER = lookup().findVirtual(Type.class, "getObject", methodType(Object.class, Block.class, int.class)) + .asType(methodType(Object.class, Type.class, ValueBlock.class, int.class)); } catch (ReflectiveOperationException e) { throw new AssertionError(e); @@ -103,7 +105,6 @@ public static MethodHandle normalizeInputMethod( List inputArgumentKinds = parameterKinds.stream() .filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) .collect(toImmutableList()); - boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL); checkArgument( boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), @@ -113,21 +114,26 @@ public static MethodHandle normalizeInputMethod( List expectedInputArgumentKinds = new ArrayList<>(); expectedInputArgumentKinds.addAll(stateArgumentKinds); - expectedInputArgumentKinds.addAll(inputArgumentKinds); - if (hasInputChannel) { - expectedInputArgumentKinds.add(BLOCK_INDEX); + for (AggregationParameterKind kind : inputArgumentKinds) { + expectedInputArgumentKinds.add(kind); + if (kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL) { + expectedInputArgumentKinds.add(BLOCK_INDEX); + } } + checkArgument( expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds); - MethodType inputMethodType = inputMethod.type(); for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) { - int parameterIndex = stateArgumentKinds.size() + argumentIndex; + int parameterIndex = stateArgumentKinds.size() + (argumentIndex * 2); AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex); if (inputArgument != INPUT_CHANNEL) { + if (inputArgument == BLOCK_INPUT_CHANNEL || inputArgument == NULLABLE_BLOCK_INPUT_CHANNEL) { + checkArgument(ValueBlock.class.isAssignableFrom(inputMethod.type().parameterType(parameterIndex)), "Expected parameter %s to be a ValueBlock", parameterIndex); + } continue; } Type argumentType = boundSignature.getArgumentType(argumentIndex); @@ -145,27 +151,9 @@ else if (argumentType.getJavaType().equals(double.class)) { } else { valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType); - valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex))); + valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethod.type().parameterType(parameterIndex))); } inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter); - - // move the position argument to the end (and combine with other existing position argument) - inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class); - - ArrayList reorder; - if (hasInputChannel) { - reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount); - } - else { - inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class); - reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new)); - int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount; - reorder.remove(positionParameterIndex); - reorder.add(parameterIndex + 1, positionParameterIndex); - hasInputChannel = true; - } - inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray()); } return inputMethod; } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java new file mode 100644 index 000000000000..e7b7dd678452 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationLoopBuilder.java @@ -0,0 +1,331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; +import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; +import io.airlift.bytecode.expression.BytecodeExpressions; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.function.GroupedAccumulatorState; +import io.trino.sql.gen.CallSiteBinder; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodType; +import java.lang.reflect.Method; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.cycle; +import static com.google.common.collect.Iterables.limit; +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.STATIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; +import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; +import static io.trino.sql.gen.BytecodeUtils.invoke; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; + +final class AggregationLoopBuilder +{ + private AggregationLoopBuilder() {} + + /** + * Build a loop over the aggregation function. Internally, there are multiple loops generated that are specialized for + * RLE, Dictionary, and basic blocks, and for masked or unmasked input. The method handle is expected to have a {@link Block} and int + * position argument for each parameter. The returned method handle signature, will start with as {@link AggregationMask} + * and then a single {@link Block} for each parameter. + */ + public static MethodHandle buildLoop(MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + verifyFunctionSignature(function, stateCount, parameterCount); + CallSiteBinder binder = new CallSiteBinder(); + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, STATIC, FINAL), + makeClassName("AggregationLoop"), + type(Object.class)); + + definition.declareDefaultConstructor(a(PRIVATE)); + + buildSpecializedLoop(binder, definition, function, stateCount, parameterCount, grouped); + + Class clazz = defineClass(definition, Object.class, binder.getBindings(), AggregationLoopBuilder.class.getClassLoader()); + + // it is simpler to find the method with reflection than using lookup().findStatic because of the complex signature + Method invokeMethod = Arrays.stream(clazz.getMethods()) + .filter(method -> method.getName().equals("invoke")) + .collect(onlyElement()); + + try { + return lookup().unreflect(invokeMethod); + } + catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static void buildSpecializedLoop(CallSiteBinder binder, ClassDefinition classDefinition, MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + AggregationParameters aggregationParameters = AggregationParameters.create(function, stateCount, parameterCount, grouped); + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC, STATIC), + "invoke", + type(void.class), + aggregationParameters.allParameters()); + + Function, BytecodeNode> coreLoopBuilder = (blockTypes) -> { + MethodDefinition method = buildCoreLoop(binder, classDefinition, function, blockTypes, aggregationParameters); + return invokeStatic(method, aggregationParameters.allParameters().toArray(new BytecodeExpression[0])); + }; + + BytecodeNode bytecodeNode = buildLoopSelection(coreLoopBuilder, new ArrayDeque<>(parameterCount), new ArrayDeque<>(aggregationParameters.blocks())); + methodDefinition.getBody() + .append(bytecodeNode) + .ret(); + } + + private static BytecodeNode buildLoopSelection(Function, BytecodeNode> coreLoopBuilder, ArrayDeque currentTypes, ArrayDeque remainingParameters) + { + if (remainingParameters.isEmpty()) { + return coreLoopBuilder.apply(ImmutableList.copyOf(currentTypes)); + } + + // remove the next parameter from the queue + Parameter blockParameter = remainingParameters.removeFirst(); + + currentTypes.addLast(BlockType.VALUE); + BytecodeNode valueLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + currentTypes.addLast(BlockType.DICTIONARY); + BytecodeNode dictionaryLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + currentTypes.addLast(BlockType.RLE); + BytecodeNode rleLoop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters); + currentTypes.removeLast(); + + IfStatement blockTypeSelection = new IfStatement() + .condition(blockParameter.instanceOf(ValueBlock.class)) + .ifTrue(valueLoop) + .ifFalse(new IfStatement() + .condition(blockParameter.instanceOf(DictionaryBlock.class)) + .ifTrue(dictionaryLoop) + .ifFalse(new IfStatement() + .condition(blockParameter.instanceOf(RunLengthEncodedBlock.class)) + .ifTrue(rleLoop) + .ifFalse(new BytecodeBlock() + .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable"))) + .throwObject()))); + + // restore the parameter to the queue + remainingParameters.addFirst(blockParameter); + + return blockTypeSelection; + } + + private static MethodDefinition buildCoreLoop( + CallSiteBinder binder, + ClassDefinition classDefinition, + MethodHandle function, + List blockTypes, + AggregationParameters aggregationParameters) + { + StringBuilder methodName = new StringBuilder("invoke_"); + for (BlockType blockType : blockTypes) { + methodName.append(blockType.name().charAt(0)); + } + + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC, STATIC), + methodName.toString(), + type(void.class), + aggregationParameters.allParameters()); + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + + Variable position = scope.declareVariable(int.class, "position"); + + ImmutableList.Builder aggregationArguments = ImmutableList.builder(); + aggregationArguments.addAll(aggregationParameters.states()); + addBlockPositionArguments(methodDefinition, position, blockTypes, aggregationParameters.blocks(), aggregationArguments); + aggregationArguments.addAll(aggregationParameters.lambdas()); + + BytecodeBlock invokeFunction = new BytecodeBlock(); + if (aggregationParameters.groupIds().isPresent()) { + // set groupId on state variables + Variable groupId = scope.declareVariable(int.class, "groupId"); + invokeFunction.append(groupId.set(aggregationParameters.groupIds().get().getElement(position))); + for (Parameter stateParameter : aggregationParameters.states()) { + invokeFunction.append(stateParameter.cast(GroupedAccumulatorState.class).invoke("setGroupId", void.class, groupId.cast(long.class))); + } + } + invokeFunction.append(invoke(binder.bind(function), "input", aggregationArguments.build())); + + Variable positionCount = scope.declareVariable("positionCount", body, aggregationParameters.mask().invoke("getSelectedPositionCount", int.class)); + + ForLoop selectAllLoop = new ForLoop() + .initialize(position.set(constantInt(0))) + .condition(lessThan(position, positionCount)) + .update(position.increment()) + .body(invokeFunction); + + Variable index = scope.declareVariable("index", body, constantInt(0)); + Variable selectedPositions = scope.declareVariable(int[].class, "selectedPositions"); + ForLoop maskedLoop = new ForLoop() + .initialize(selectedPositions.set(aggregationParameters.mask().invoke("getSelectedPositions", int[].class))) + .condition(lessThan(index, positionCount)) + .update(index.increment()) + .body(new BytecodeBlock() + .append(position.set(selectedPositions.getElement(index))) + .append(invokeFunction)); + + body.append(new IfStatement() + .condition(aggregationParameters.mask().invoke("isSelectAll", boolean.class)) + .ifTrue(selectAllLoop) + .ifFalse(maskedLoop)); + body.ret(); + return methodDefinition; + } + + private static void addBlockPositionArguments( + MethodDefinition methodDefinition, + Variable position, + List blockTypes, + List blockParameters, + ImmutableList.Builder aggregationArguments) + { + Scope scope = methodDefinition.getScope(); + BytecodeBlock methodBody = methodDefinition.getBody(); + + for (int i = 0; i < blockTypes.size(); i++) { + BlockType blockType = blockTypes.get(i); + switch (blockType) { + case VALUE -> { + aggregationArguments.add(blockParameters.get(i).cast(ValueBlock.class)); + aggregationArguments.add(position); + } + case DICTIONARY -> { + Variable valueBlock = scope.declareVariable( + "valueBlock" + i, + methodBody, + blockParameters.get(i).cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class)); + Variable rawIds = scope.declareVariable( + "rawIds" + i, + methodBody, + blockParameters.get(i).cast(DictionaryBlock.class).invoke("getRawIds", int[].class)); + Variable rawIdsOffset = scope.declareVariable( + "rawIdsOffset" + i, + methodBody, + blockParameters.get(i).cast(DictionaryBlock.class).invoke("getRawIdsOffset", int.class)); + aggregationArguments.add(valueBlock); + aggregationArguments.add(rawIds.getElement(BytecodeExpressions.add(rawIdsOffset, position))); + } + case RLE -> { + Variable valueBlock = scope.declareVariable( + "valueBlock" + i, + methodBody, + blockParameters.get(i).cast(RunLengthEncodedBlock.class).invoke("getValue", ValueBlock.class)); + aggregationArguments.add(valueBlock); + aggregationArguments.add(constantInt(0)); + } + } + } + } + + private static void verifyFunctionSignature(MethodHandle function, int stateCount, int parameterCount) + { + // verify signature + List> expectedParameterTypes = ImmutableList.>builder() + .addAll(function.type().parameterList().subList(0, stateCount)) + .addAll(limit(cycle(ValueBlock.class, int.class), parameterCount * 2)) + .addAll(function.type().parameterList().subList(stateCount + (parameterCount * 2), function.type().parameterCount())) + .build(); + MethodType expectedSignature = methodType(void.class, expectedParameterTypes); + checkArgument(function.type().equals(expectedSignature), "Expected function signature to be %s, but is %s", expectedSignature, function.type()); + } + + private record AggregationParameters(Parameter mask, Optional groupIds, List states, List blocks, List lambdas) + { + static AggregationParameters create(MethodHandle function, int stateCount, int parameterCount, boolean grouped) + { + Parameter mask = arg("aggregationMask", AggregationMask.class); + + Optional groupIds = Optional.empty(); + if (grouped) { + groupIds = Optional.of(arg("groupIds", int[].class)); + } + + ImmutableList.Builder states = ImmutableList.builder(); + for (int i = 0; i < stateCount; i++) { + states.add(arg("state" + i, function.type().parameterType(i))); + } + + ImmutableList.Builder parameters = ImmutableList.builder(); + for (int i = 0; i < parameterCount; i++) { + parameters.add(arg("block" + i, Block.class)); + } + + ImmutableList.Builder lambdas = ImmutableList.builder(); + int lambdaFunctionOffset = stateCount + (parameterCount * 2); + for (int i = 0; i < function.type().parameterCount() - lambdaFunctionOffset; i++) { + lambdas.add(arg("lambda" + i, function.type().parameterType(lambdaFunctionOffset + i))); + } + + return new AggregationParameters(mask, groupIds, states.build(), parameters.build(), lambdas.build()); + } + + public List allParameters() + { + return ImmutableList.builder() + .add(mask) + .addAll(groupIds.stream().iterator()) + .addAll(states) + .addAll(blocks) + .addAll(lambdas) + .build(); + } + } + + private enum BlockType + { + RLE, DICTIONARY, VALUE + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java index e28492becf3f..1679e3ece3ea 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -16,8 +16,8 @@ import com.google.common.annotations.VisibleForTesting; import io.airlift.stats.cardinality.HyperLogLog; import io.trino.operator.aggregation.state.HyperLogLogState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -52,7 +52,7 @@ private ApproximateCountDistinctAggregation() {} @InputFunction public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index, @SqlType(StandardTypes.DOUBLE) double maxStandardError) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java index 459dabae0daf..4791fe78a83d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateSetGenericAggregation.java @@ -16,8 +16,8 @@ import io.airlift.stats.cardinality.HyperLogLog; import io.trino.operator.aggregation.state.HyperLogLogState; import io.trino.operator.aggregation.state.StateCompiler; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -51,7 +51,7 @@ private ApproximateSetGenericAggregation() {} @InputFunction public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index) { // do nothing -- unknown type is always NULL diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java index 28d778a5489a..e62a9ea91adf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ArbitraryAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -37,7 +37,7 @@ private ArbitraryAggregationFunction() {} @TypeParameter("T") public static void input( @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java index 51452b0e4022..92bbc6e40326 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ChecksumAggregationFunction.java @@ -17,8 +17,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.operator.aggregation.state.NullableLongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,7 +36,7 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.VarbinaryType.VARBINARY; @@ -55,10 +55,10 @@ public static void input( @OperatorDependency( operator = OperatorType.XX_HASH_64, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle xxHash64Operator, @AggregationState NullableLongState state, - @SqlNullable @BlockPosition @SqlType("T") Block block, + @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java index 9163dab81c84..87ccef50fbec 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CountColumn.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -40,7 +40,7 @@ private CountColumn() {} @TypeParameter("T") public static void input( @AggregationState LongState state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) { state.setValue(state.getValue() + 1); @@ -49,7 +49,7 @@ public static void input( @RemoveInputFunction public static void removeInput( @AggregationState LongState state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) { state.setValue(state.getValue() - 1); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java index 69728b943941..c62754cac935 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalAverageAggregation.java @@ -15,8 +15,8 @@ import com.google.common.annotations.VisibleForTesting; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -80,7 +80,7 @@ public static void inputShortDecimal( @LiteralParameters({"p", "s"}) public static void inputLongDecimal( @AggregationState LongDecimalWithOverflowAndLongState state, - @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Block block, + @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Int128ArrayBlock block, @BlockIndex int position) { state.addLong(1); // row counter diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java index 6439dbc23483..f256a8748777 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DecimalSumAggregation.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -66,7 +66,7 @@ public static void inputShortDecimal( @LiteralParameters({"p", "s"}) public static void inputLongDecimal( @AggregationState LongDecimalWithOverflowState state, - @BlockPosition @SqlType(value = "decimal(p,s)", nativeContainerType = Int128.class) Block block, + @BlockPosition @SqlType(value = "decimal(p,s)", nativeContainerType = Int128.class) Int128ArrayBlock block, @BlockIndex int position) { state.setNotNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index ee7c9e7de10e..517fa4df07b4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.HyperLogLogState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -45,7 +45,7 @@ private DefaultApproximateCountDistinctAggregation() {} @InputFunction public static void input( @AggregationState HyperLogLogState state, - @BlockPosition @SqlType("unknown") Block block, + @BlockPosition @SqlType("unknown") ValueBlock block, @BlockIndex int index) { // do nothing diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java index 2084d33a0964..0320d955a761 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedMapAggregationState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.Type; @@ -65,7 +65,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(groupId, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java index 0cf618037f66..a1f1ce4b57cd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationFunction.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,11 +39,12 @@ private MapAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MapAggregationState state, - @BlockPosition @SqlType("K") Block key, - @SqlNullable @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockPosition @SqlType("K") ValueBlock key, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition @SqlType("V") ValueBlock value, + @BlockIndex int valuePosition) { - state.add(key, position, value, position); + state.add(key, keyPosition, value, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java index 0d4a7886a1d3..f1fdbe3122f9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationState.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -27,7 +28,7 @@ public interface MapAggregationState extends AccumulatorState { - void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); default void merge(MapAggregationState other) { @@ -36,8 +37,10 @@ default void merge(MapAggregationState other) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < serializedState.getSize(); i++) { - add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i); + add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java index 8f6ae5c435db..ddb2a4630a54 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -52,7 +52,7 @@ public MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", @@ -60,11 +60,11 @@ public MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, @TypeParameter("V") Type valueType, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -73,7 +73,7 @@ public MapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) { this.keyType = requireNonNull(keyType, "keyType is null"); this.keyReadFlat = requireNonNull(keyReadFlat, "keyReadFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java index 9a247f9e8b82..718090b4601f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MapUnionAggregation.java @@ -17,6 +17,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; @@ -45,8 +46,10 @@ public static void input( Block rawKeyBlock = value.getRawKeyBlock(); Block rawValueBlock = value.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < value.getSize(); i++) { - state.add(rawKeyBlock, rawOffset + i, rawValueBlock, rawOffset + i); + state.add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java index b3b500720ddc..6ec2f540c84f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -32,8 +32,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("max") @@ -48,10 +48,10 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java index c3fc6fd8a6ab..1e7a2f1294d9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxByAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -33,8 +33,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("max_by") @@ -50,18 +50,19 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) > 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) > 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java index d2ab6797150c..317e16ba8649 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxDataSizeForStats.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -40,7 +40,7 @@ private MaxDataSizeForStats() {} @InputFunction @TypeParameter("T") - public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { update(state, block.getEstimatedDataSizeForStats(index)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java index 2c18f112974d..5076734a0a93 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MergeQuantileDigestFunction.java @@ -15,8 +15,8 @@ import io.airlift.stats.QuantileDigest; import io.trino.operator.aggregation.state.QuantileDigestState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -45,7 +45,7 @@ private MergeQuantileDigestFunction() {} public static void input( @TypeParameter("qdigest(V)") Type type, @AggregationState QuantileDigestState state, - @BlockPosition @SqlType("qdigest(V)") Block value, + @BlockPosition @SqlType("qdigest(V)") ValueBlock value, @BlockIndex int index) { merge(state, new QuantileDigest(type.getSlice(value, index))); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java index acf8e408dbea..8616b7c2116c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -32,8 +32,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("min") @@ -48,10 +48,10 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("T") InOut state, - @BlockPosition @SqlType("T") Block block, + @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int position) throws Throwable { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java index 6d8520d3cf0f..3c79a80adc1f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinByAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -33,8 +33,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @AggregationFunction("min_by") @@ -50,18 +50,19 @@ public static void input( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {VALUE_BLOCK_POSITION_NOT_NULL, IN_OUT}, result = FAIL_ON_NULL)) MethodHandle compare, @AggregationState("K") InOut keyState, @AggregationState("V") InOut valueState, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @BlockIndex int position) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition) throws Throwable { - if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, position, keyState)) < 0) { - keyState.set(keyBlock, position); - valueState.set(valueBlock, position); + if (keyState.isNull() || ((long) compare.invokeExact(keyBlock, keyPosition, keyState)) < 0) { + keyState.set(keyBlock, keyPosition); + valueState.set(valueBlock, valuePosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java index 3aef3c7f2ff5..ba77f57d03dc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java @@ -19,7 +19,7 @@ import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.annotations.ImplementationDependency; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -220,7 +220,7 @@ public boolean areTypesAssignable(BoundSignature boundSignature) // block and position works for any type, but if block is annotated with SqlType nativeContainerType, then only types with the // specified container type match - if (isCurrentBlockPosition && methodDeclaredType.isAssignableFrom(Block.class)) { + if (isCurrentBlockPosition && ValueBlock.class.isAssignableFrom(methodDeclaredType)) { continue; } if (methodDeclaredType.isAssignableFrom(argumentType)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java index 3606543c8f9f..1d6fd3b8421d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/SingleMapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.type.Type; @@ -61,7 +61,7 @@ private SingleMapAggregationState(SingleMapAggregationState state) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(0, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java index 415299c729d0..04f2f607f8dd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/SumDataSizeForStats.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import io.trino.operator.aggregation.state.LongState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,7 +38,7 @@ private SumDataSizeForStats() {} @InputFunction @TypeParameter("T") - public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@AggregationState LongState state, @SqlNullable @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { update(state, block.getEstimatedDataSizeForStats(index)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java index 515780bb75fc..d4e969d50bd6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationFunction.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.arrayagg; import io.trino.spi.block.ArrayBlockBuilder; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,7 +39,7 @@ private ArrayAggregationFunction() {} @TypeParameter("T") public static void input( @AggregationState("T") ArrayAggregationState state, - @SqlNullable @BlockPosition @SqlType("T") Block value, + @SqlNullable @BlockPosition @SqlType("T") ValueBlock value, @BlockIndex int position) { state.add(value, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java index 77018642d751..4488e1708ff4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationState.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -26,9 +27,7 @@ public interface ArrayAggregationState extends AccumulatorState { - void addAll(Block block); - - void add(Block block, int position); + void add(ValueBlock block, int position); void writeAll(BlockBuilder blockBuilder); @@ -36,6 +35,10 @@ public interface ArrayAggregationState default void merge(ArrayAggregationState otherState) { - addAll(((SingleArrayAggregationState) otherState).removeTempDeserializeBlock()); + Block block = ((SingleArrayAggregationState) otherState).removeTempDeserializeBlock(); + ValueBlock valueBlock = block.getUnderlyingValueBlock(); + for (int position = 0; position < block.getPositionCount(); position++) { + add(valueBlock, block.getUnderlyingValuePosition(position)); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java index 7694e78b75aa..9176c313398e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/ArrayAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -43,7 +43,7 @@ public ArrayAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @TypeParameter("T") Type type) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java index 1e2e3be1140f..57c2121508b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/FlatArrayBuilder.java @@ -15,8 +15,8 @@ import com.google.common.base.Throwables; import io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -154,7 +154,7 @@ public void setNextIndex(long tailIndex, long nextIndex) LONG_HANDLE.set(records, recordOffset + recordNextIndexOffset, nextIndex); } - public void add(Block block, int position) + public void add(ValueBlock block, int position) { if (size == capacity) { growCapacity(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java index 5d5f3e9bcba3..84381a9aa0ac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/GroupArrayAggregationState.java @@ -15,8 +15,8 @@ import com.google.common.primitives.Ints; import io.trino.operator.aggregation.state.AbstractGroupedAccumulatorState; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -72,15 +72,7 @@ public void ensureCapacity(long maxGroupId) } @Override - public void addAll(Block block) - { - for (int position = 0; position < block.getPositionCount(); position++) { - add(block, position); - } - } - - @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { int groupId = (int) getGroupId(); long index = arrayBuilder.size(); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java index 64acf0744148..30fcb7acdbc3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/arrayagg/SingleArrayAggregationState.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -52,15 +53,7 @@ public long getEstimatedSize() } @Override - public void addAll(Block block) - { - for (int position = 0; position < block.getPositionCount(); position++) { - add(block, position); - } - } - - @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { arrayBuilder.add(block, position); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java index 2697ae84556b..ad9675303706 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/GroupedHistogramState.java @@ -15,8 +15,8 @@ package io.trino.operator.aggregation.histogram; import io.trino.operator.aggregation.state.AbstractGroupedAccumulatorState; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -50,7 +50,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block block, int position, long count) + public void add(ValueBlock block, int position, long count) { histogram.add(toIntExact(getGroupId()), block, position, count); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java index 7dccd97cd48e..a835c4780cc0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/Histogram.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.histogram; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,7 +39,7 @@ private Histogram() {} public static void input( @TypeParameter("T") Type type, @AggregationState("T") HistogramState state, - @BlockPosition @SqlType("T") Block key, + @BlockPosition @SqlType("T") ValueBlock key, @BlockIndex int position) { state.add(key, position, 1L); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java index 32b18321b2ee..b0ae54e64333 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramState.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -29,7 +30,7 @@ public interface HistogramState extends AccumulatorState { - void add(Block block, int position, long count); + void add(ValueBlock block, int position, long count); default void merge(HistogramState other) { @@ -38,8 +39,10 @@ default void merge(HistogramState other) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); + ValueBlock rawValueValues = rawValueBlock.getUnderlyingValueBlock(); for (int i = 0; i < serializedState.getSize(); i++) { - add(rawKeyBlock, rawOffset + i, BIGINT.getLong(rawValueBlock, rawOffset + i)); + add(rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i), BIGINT.getLong(rawValueValues, rawValueBlock.getUnderlyingValuePosition(rawOffset + i))); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java index 096c10e044e9..4a11a67f39d1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/HistogramStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -48,7 +48,7 @@ public HistogramStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", @@ -56,11 +56,11 @@ public HistogramStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock) { this.type = requireNonNull(type, "type is null"); this.readFlat = requireNonNull(readFlat, "readFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java index 73a15cc2dd71..c6a3494b6bae 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/SingleHistogramState.java @@ -14,9 +14,9 @@ package io.trino.operator.aggregation.histogram; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -56,7 +56,7 @@ public SingleHistogramState( } @Override - public void add(Block block, int position, long count) + public void add(ValueBlock block, int position, long count) { if (typedHistogram == null) { typedHistogram = new TypedHistogram(keyType, readFlat, writeFlat, hashFlat, distinctFlatBlock, hashBlock, false); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java index 95656f8e6ca3..e40f503047a0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/histogram/TypedHistogram.java @@ -17,9 +17,9 @@ import com.google.common.primitives.Ints; import io.trino.operator.VariableWidthData; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -237,7 +237,7 @@ private void serializeEntry(BlockBuilder keyBuilder, BlockBuilder valueBuilder, BIGINT.writeLong(valueBuilder, (long) LONG_HANDLE.get(records, recordOffset + recordCountOffset)); } - public void add(int groupId, Block block, int position, long count) + public void add(int groupId, ValueBlock block, int position, long count) { checkArgument(!block.isNull(position), "value must not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -275,7 +275,7 @@ public void add(int groupId, Block block, int position, long count) } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -306,7 +306,7 @@ private void addCount(int index, long increment) LONG_HANDLE.set(records, countOffset, (long) LONG_HANDLE.get(records, countOffset) + increment); } - private void insert(int index, int groupId, Block block, int position, long count, byte hashPrefix) + private void insert(int index, int groupId, ValueBlock block, int position, long count, byte hashPrefix) { setControl(index, hashPrefix); @@ -455,7 +455,7 @@ private long valueHashCode(int groupId, byte[] records, int index) } } - private long valueHashCode(int groupId, Block right, int rightPosition) + private long valueHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) hashBlock.invokeExact(right, rightPosition); @@ -467,7 +467,7 @@ private long valueHashCode(int groupId, Block right, int rightPosition) } } - private boolean valueNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean valueNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java index ea225a2a9af4..d06b00295590 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/AbstractListaggAggregationState.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.ArrayType; @@ -153,7 +154,7 @@ void setMaxOutputLength(int maxOutputLength) } @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { checkArgument(!block.isNull(position), "element is null"); @@ -231,9 +232,10 @@ public void merge(ListaggAggregationState other) boolean showOverflowEntryCount = BOOLEAN.getBoolean(fields.get(3), index); initialize(separator, overflowError, overflowFiller, showOverflowEntryCount); - Block values = new ArrayType(VARCHAR).getObject(fields.get(4), index); - for (int i = 0; i < values.getPositionCount(); i++) { - add(values, i); + Block array = new ArrayType(VARCHAR).getObject(fields.get(4), index); + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + for (int i = 0; i < array.getPositionCount(); i++) { + add(arrayValues, arrayValues.getUnderlyingValuePosition(i)); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java index 716b10607657..863c3987c51a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/GroupListaggAggregationState.java @@ -15,7 +15,7 @@ import com.google.common.primitives.Ints; import io.airlift.slice.SliceOutput; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; @@ -102,7 +102,7 @@ public void ensureCapacity(long maxGroupId) } @Override - public void add(Block block, int position) + public void add(ValueBlock block, int position) { super.add(block, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java index 1bd177cedc97..738b1ffb806a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationFunction.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.listagg; import io.airlift.slice.Slice; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; @@ -36,12 +36,12 @@ private ListaggAggregationFunction() {} @InputFunction public static void input( @AggregationState ListaggAggregationState state, - @BlockPosition @SqlType("VARCHAR") Block value, + @BlockPosition @SqlType("VARCHAR") ValueBlock value, + @BlockIndex int position, @SqlType("VARCHAR") Slice separator, @SqlType("BOOLEAN") boolean overflowError, @SqlType("VARCHAR") Slice overflowFiller, - @SqlType("BOOLEAN") boolean showOverflowEntryCount, - @BlockIndex int position) + @SqlType("BOOLEAN") boolean showOverflowEntryCount) { state.initialize(separator, overflowError, overflowFiller, showOverflowEntryCount); state.add(value, position); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java index 107350904832..c5b168c06d3d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/listagg/ListaggAggregationState.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation.listagg; import io.airlift.slice.Slice; -import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -28,7 +28,7 @@ public interface ListaggAggregationState { void initialize(Slice separator, boolean overflowError, Slice overflowFiller, boolean showOverflowEntryCount); - void add(Block block, int position); + void add(ValueBlock block, int position); void serialize(RowBlockBuilder out); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java index bbeef22c095c..2b3ed17f7512 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,13 +38,14 @@ private MaxByNAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MaxByNState state, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java index a1fa006bf3fa..6c2b06af5ddf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MaxByNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -60,7 +60,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, @@ -70,7 +70,7 @@ public MaxByNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java index 5036d19f87ab..451240b03d6b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -38,13 +38,14 @@ private MinByNAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MinByNState state, - @SqlNullable @BlockPosition @SqlType("V") Block valueBlock, - @BlockPosition @SqlType("K") Block keyBlock, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlType("BIGINT") long n) { state.initialize(n); - state.add(keyBlock, valueBlock, blockIndex); + state.add(keyBlock, keyPosition, valueBlock, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java index 79404c8e337b..644f586789e9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinByNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MinByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -60,7 +60,7 @@ public MinByNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, @@ -70,7 +70,7 @@ public MinByNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("K") Type keyType, @TypeParameter("V") Type valueType) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java index 516a0d2fea1b..adf254926460 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxbyn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; public interface MinMaxByNState @@ -29,7 +29,7 @@ public interface MinMaxByNState /** * Adds the value to this state. */ - void add(Block keyBlock, Block valueBlock, int position); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); /** * Merge with the specified state. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java index b69de366e942..a2b63e971139 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/MinMaxByNStateFactory.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.ArrayType; @@ -50,7 +51,12 @@ public final void merge(MinMaxByNState other) Block keys = new ArrayType(typedKeyValueHeap.getKeyType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); Block values = new ArrayType(typedKeyValueHeap.getValueType()).getObject(sqlRow.getRawFieldBlock(2), rawIndex); - typedKeyValueHeap.addAll(keys, values); + + ValueBlock rawKeyValues = keys.getUnderlyingValueBlock(); + ValueBlock rawValueValues = values.getUnderlyingValueBlock(); + for (int i = 0; i < keys.getPositionCount(); i++) { + typedKeyValueHeap.add(rawKeyValues, keys.getUnderlyingValuePosition(i), rawValueValues, values.getUnderlyingValuePosition(i)); + } } @Override @@ -118,12 +124,12 @@ public final void initialize(long n) } @Override - public final void add(Block keyBlock, Block valueBlock, int position) + public final void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { TypedKeyValueHeap typedHeap = getTypedKeyValueHeap(); size -= typedHeap.getEstimatedSize(); - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); size += typedHeap.getEstimatedSize(); } @@ -203,9 +209,9 @@ public final void initialize(long n) } @Override - public final void add(Block keyBlock, Block valueBlock, int position) + public final void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { - typedHeap.add(keyBlock, valueBlock, position); + typedHeap.add(keyBlock, keyPosition, valueBlock, valuePosition); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java index afe2b2b6a394..4b7afb267fa2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxbyn/TypedKeyValueHeap.java @@ -16,8 +16,8 @@ import com.google.common.base.Throwables; import io.airlift.slice.SizeOf; import io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; import jakarta.annotation.Nullable; @@ -227,27 +227,20 @@ private void write(int index, @Nullable BlockBuilder keyBlockBuilder, BlockBuild } } - public void addAll(Block keyBlock, Block valueBlock) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { - for (int i = 0; i < keyBlock.getPositionCount(); i++) { - add(keyBlock, valueBlock, i); - } - } - - public void add(Block keyBlock, Block valueBlock, int position) - { - checkArgument(!keyBlock.isNull(position)); + checkArgument(!keyBlock.isNull(keyPosition)); if (positionCount == capacity) { // is it possible the value is within the top N values? - if (!shouldConsiderValue(keyBlock, position)) { + if (!shouldConsiderValue(keyBlock, keyPosition)) { return; } clear(0); - set(0, keyBlock, valueBlock, position); + set(0, keyBlock, keyPosition, valueBlock, valuePosition); siftDown(); } else { - set(positionCount, keyBlock, valueBlock, position); + set(positionCount, keyBlock, keyPosition, valueBlock, valuePosition); positionCount++; siftUp(); } @@ -274,7 +267,7 @@ private void clear(int index) }); } - private void set(int index, Block keyBlock, Block valueBlock, int position) + private void set(int index, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { int recordOffset = getRecordOffset(index); @@ -283,28 +276,28 @@ private void set(int index, Block keyBlock, Block valueBlock, int position) int keyVariableWidthLength = 0; if (variableWidthData != null) { if (keyVariableWidth) { - keyVariableWidthLength = keyType.getFlatVariableWidthSize(keyBlock, position); + keyVariableWidthLength = keyType.getFlatVariableWidthSize(keyBlock, keyPosition); } - int valueVariableWidthLength = valueType.getFlatVariableWidthSize(valueBlock, position); + int valueVariableWidthLength = valueType.getFlatVariableWidthSize(valueBlock, valuePosition); variableWidthChunk = variableWidthData.allocate(fixedChunk, recordOffset, keyVariableWidthLength + valueVariableWidthLength); variableWidthChunkOffset = getChunkOffset(fixedChunk, recordOffset); } try { - keyWriteFlat.invokeExact(keyBlock, position, fixedChunk, recordOffset + recordKeyOffset, variableWidthChunk, variableWidthChunkOffset); + keyWriteFlat.invokeExact(keyBlock, keyPosition, fixedChunk, recordOffset + recordKeyOffset, variableWidthChunk, variableWidthChunkOffset); } catch (Throwable throwable) { Throwables.throwIfUnchecked(throwable); throw new RuntimeException(throwable); } - if (valueBlock.isNull(position)) { + if (valueBlock.isNull(valuePosition)) { fixedChunk[recordOffset + recordKeyOffset - 1] = 1; } else { try { valueWriteFlat.invokeExact( valueBlock, - position, + valuePosition, fixedChunk, recordOffset + recordValueOffset, variableWidthChunk, @@ -394,7 +387,7 @@ private int compare(int leftPosition, int rightPosition) } } - private boolean shouldConsiderValue(Block right, int rightPosition) + private boolean shouldConsiderValue(ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = fixedChunk; int leftRecordOffset = getRecordOffset(0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java index 1c385121520f..df02803a46f4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,9 +36,9 @@ private MaxNAggregationFunction() {} @TypeParameter("E") public static void input( @AggregationState("E") MaxNState state, - @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockPosition @SqlType("E") ValueBlock block, + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java index 9a49b8057a46..81c8fa1b9681 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MaxNStateFactory.java @@ -27,8 +27,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -50,7 +50,7 @@ public MaxNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, @@ -60,7 +60,7 @@ public MaxNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("T") Type elementType) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java index 97144ddc2f04..1c61c61fec2d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; public interface MinMaxNState @@ -29,7 +29,7 @@ public interface MinMaxNState /** * Adds the value to this state. */ - void add(Block block, int position); + void add(ValueBlock block, int position); /** * Merge with the specified state. diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java index bd74cce3dbc9..fc94133942b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinMaxNStateFactory.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.ArrayType; @@ -51,8 +52,11 @@ public final void merge(MinMaxNState other) initialize(capacity); TypedHeap typedHeap = getTypedHeap(); - Block values = new ArrayType(typedHeap.getElementType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); - typedHeap.addAll(values); + Block array = new ArrayType(typedHeap.getElementType()).getObject(sqlRow.getRawFieldBlock(1), rawIndex); + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + for (int i = 0; i < array.getPositionCount(); i++) { + typedHeap.add(arrayValues, array.getUnderlyingValuePosition(i)); + } } @Override @@ -118,7 +122,7 @@ public final void initialize(long n) } @Override - public final void add(Block block, int position) + public final void add(ValueBlock block, int position) { TypedHeap typedHeap = getTypedHeap(); @@ -200,7 +204,7 @@ public final void initialize(long n) } @Override - public final void add(Block block, int position) + public final void add(ValueBlock block, int position) { typedHeap.add(block, position); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java index 4521f979d6ce..3f4f5a78ceb0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNAggregationFunction.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.minmaxn; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -36,9 +36,9 @@ private MinNAggregationFunction() {} @TypeParameter("E") public static void input( @AggregationState("E") MinNState state, - @BlockPosition @SqlType("E") Block block, - @SqlType("BIGINT") long n, - @BlockIndex int blockIndex) + @BlockPosition @SqlType("E") ValueBlock block, + @BlockIndex int blockIndex, + @SqlType("BIGINT") long n) { state.initialize(n); state.add(block, blockIndex); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java index 46715f2115bc..99fe5b6496d4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/MinNStateFactory.java @@ -26,8 +26,8 @@ import java.util.function.LongFunction; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -49,7 +49,7 @@ public MinNStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, @@ -59,7 +59,7 @@ public MinNStateFactory( @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareFlatBlock, @TypeParameter("T") Type elementType) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java index 586e0e372477..7ba0168077d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/minmaxn/TypedHeap.java @@ -17,8 +17,8 @@ import com.google.common.primitives.Ints; import io.airlift.slice.SizeOf; import io.trino.operator.VariableWidthData; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; @@ -184,14 +184,7 @@ private void write(int index, BlockBuilder blockBuilder) } } - public void addAll(Block block) - { - for (int i = 0; i < block.getPositionCount(); i++) { - add(block, i); - } - } - - public void add(Block block, int position) + public void add(ValueBlock block, int position) { checkArgument(!block.isNull(position)); if (positionCount == capacity) { @@ -227,7 +220,7 @@ private void clear(int index) elementType.relocateFlatVariableWidthOffsets(fixedChunk, fixedSizeOffset + recordElementOffset, variableWidthChunk, variableWidthChunkOffset)); } - private void set(int index, Block block, int position) + private void set(int index, ValueBlock block, int position) { int recordOffset = getRecordOffset(index); @@ -325,7 +318,7 @@ private int compare(int leftPosition, int rightPosition) } } - private boolean shouldConsiderValue(Block right, int rightPosition) + private boolean shouldConsiderValue(ValueBlock right, int rightPosition) { byte[] leftFixedRecordChunk = fixedChunk; int leftRecordOffset = getRecordOffset(0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java index 1ff2b3e4d7d5..5a69677e9168 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/AbstractMultimapAggregationState.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -299,24 +300,26 @@ protected void deserialize(int groupId, SqlMap serializedState) Block rawKeyBlock = serializedState.getRawKeyBlock(); Block rawValueBlock = serializedState.getRawValueBlock(); + ValueBlock rawKeyValues = rawKeyBlock.getUnderlyingValueBlock(); ArrayType arrayType = new ArrayType(valueArrayBuilder.type()); for (int i = 0; i < serializedState.getSize(); i++) { - int keyId = putKeyIfAbsent(groupId, rawKeyBlock, rawOffset + i); + int keyId = putKeyIfAbsent(groupId, rawKeyValues, rawKeyBlock.getUnderlyingValuePosition(rawOffset + i)); Block array = arrayType.getObject(rawValueBlock, rawOffset + i); verify(array.getPositionCount() > 0, "array is empty"); + ValueBlock arrayValuesBlock = array.getUnderlyingValueBlock(); for (int arrayIndex = 0; arrayIndex < array.getPositionCount(); arrayIndex++) { - addKeyValue(keyId, array, arrayIndex); + addKeyValue(keyId, arrayValuesBlock, array.getUnderlyingValuePosition(arrayIndex)); } } } - protected void add(int groupId, Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + protected void add(int groupId, ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { int keyId = putKeyIfAbsent(groupId, keyBlock, keyPosition); addKeyValue(keyId, valueBlock, valuePosition); } - private int putKeyIfAbsent(int groupId, Block keyBlock, int keyPosition) + private int putKeyIfAbsent(int groupId, ValueBlock keyBlock, int keyPosition) { checkArgument(!keyBlock.isNull(keyPosition), "key must not be null"); checkArgument(groupId == 0 || groupRecordIndex != null, "groupId must be zero when grouping is not enabled"); @@ -356,7 +359,7 @@ private int putKeyIfAbsent(int groupId, Block keyBlock, int keyPosition) } } - private int matchInVector(int groupId, Block block, int position, int vectorStartBucket, long repeated, long controlVector) + private int matchInVector(int groupId, ValueBlock block, int position, int vectorStartBucket, long repeated, long controlVector) { long controlMatches = match(controlVector, repeated); while (controlMatches != 0) { @@ -380,7 +383,7 @@ private int findEmptyInVector(long vector, int vectorStartBucket) return bucket(vectorStartBucket + slot); } - private int insert(int keyIndex, int groupId, Block keyBlock, int keyPosition, byte hashPrefix) + private int insert(int keyIndex, int groupId, ValueBlock keyBlock, int keyPosition, byte hashPrefix) { setControl(keyIndex, hashPrefix); @@ -430,7 +433,7 @@ private int insert(int keyIndex, int groupId, Block keyBlock, int keyPosition, b return keyId; } - private void addKeyValue(int keyId, Block valueBlock, int valuePosition) + private void addKeyValue(int keyId, ValueBlock valueBlock, int valuePosition) { long index = valueArrayBuilder.size(); if (keyTailPositions[keyId] == -1) { @@ -554,7 +557,7 @@ private long keyHashCode(int groupId, byte[] records, int index) } } - private long keyHashCode(int groupId, Block right, int rightPosition) + private long keyHashCode(int groupId, ValueBlock right, int rightPosition) { try { long valueHash = (long) keyHashBlock.invokeExact(right, rightPosition); @@ -566,7 +569,7 @@ private long keyHashCode(int groupId, Block right, int rightPosition) } } - private boolean keyNotDistinctFrom(int leftPosition, Block right, int rightPosition, int rightGroupId) + private boolean keyNotDistinctFrom(int leftPosition, ValueBlock right, int rightPosition, int rightGroupId) { byte[] leftRecords = getRecords(leftPosition); int leftRecordOffset = getRecordOffset(leftPosition); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java index e2f5fd079cee..3116bbcd16d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/GroupedMultimapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.type.Type; @@ -66,7 +66,7 @@ public void ensureCapacity(long size) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(groupId, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index d4f64d9a3660..374de6495cee 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -39,11 +39,12 @@ private MultimapAggregationFunction() {} @TypeParameter("V") public static void input( @AggregationState({"K", "V"}) MultimapAggregationState state, - @BlockPosition @SqlType("K") Block key, - @SqlNullable @BlockPosition @SqlType("V") Block value, - @BlockIndex int position) + @BlockPosition @SqlType("K") ValueBlock key, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition @SqlType("V") ValueBlock value, + @BlockIndex int valuePosition) { - state.add(key, position, value, position); + state.add(key, keyPosition, value, valuePosition); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java index 88a587be3e08..87143c148ec3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationState.java @@ -13,8 +13,8 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateMetadata; @@ -26,7 +26,7 @@ public interface MultimapAggregationState extends AccumulatorState { - void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition); + void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition); void merge(MultimapAggregationState other); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java index 7e5615dacb7c..a0682a133f0a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationStateFactory.java @@ -22,8 +22,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -52,7 +52,7 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle keyWriteFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", @@ -60,11 +60,11 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle keyDistinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "K", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle keyHashBlock, @TypeParameter("V") Type valueType, @OperatorDependency( operator = OperatorType.READ_VALUE, @@ -73,7 +73,7 @@ public MultimapAggregationStateFactory( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "V", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle valueWriteFlat) { this.keyType = requireNonNull(keyType, "keyType is null"); this.keyReadFlat = requireNonNull(keyReadFlat, "keyReadFlat is null"); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java index 269cab614349..65fa48b81eb1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/SingleMultimapAggregationState.java @@ -13,9 +13,9 @@ */ package io.trino.operator.aggregation.multimapagg; -import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AccumulatorState; import io.trino.spi.type.Type; @@ -61,7 +61,7 @@ private SingleMultimapAggregationState(SingleMultimapAggregationState state) } @Override - public void add(Block keyBlock, int keyPosition, Block valueBlock, int valuePosition) + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition) { add(0, keyBlock, keyPosition, valueBlock, valuePosition); } diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java index 20a01eb430d6..2c065f182238 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java @@ -36,7 +36,7 @@ public LocalPartitionGenerator(HashGenerator hashGenerator, int partitionCount) } @Override - public int getPartitionCount() + public int partitionCount() { return partitionCount; } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java index ac62fb3eee14..9708ae7092e4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java @@ -149,7 +149,7 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result) { int positionCount = positions.length; - int partitionCount = partitionGenerator.getPartitionCount(); + int partitionCount = partitionGenerator.partitionCount(); int[] partitions = new int[positionCount]; int[] partitionPositionsCount = new int[partitionCount]; diff --git a/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java index d640db4e0208..73dbf4360055 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/BytePositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public BytePositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof ByteArrayBlock, "Block must be instance of %s", ByteArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java index d90694082928..d7125a989c60 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/Fixed12PositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; @@ -56,9 +58,10 @@ public Fixed12PositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + if (positions.isEmpty()) { return; } @@ -100,8 +103,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + if (rlePositionCount == 0) { return; } @@ -130,8 +135,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof Fixed12Block, "Block must be instance of %s", Fixed12Block.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -202,7 +209,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize * 3); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java index fceb70eb4d28..4198091cc548 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/Int128PositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; @@ -56,9 +58,10 @@ public Int128PositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -101,8 +104,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -129,8 +134,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof Int128ArrayBlock, "Block must be instance of %s", Int128ArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -200,7 +207,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize * 2); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java index f4b28b1c5a0b..290d395d5d75 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/IntPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public IntPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof IntArrayBlock, "Block must be instance of %s", IntArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java index 2a5910efdec0..1b378ca502b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/LongPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public LongPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof LongArrayBlock, "Block must be instance of %s", LongArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java index 72be3d3e4277..78e4a2ff12f5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java @@ -115,7 +115,7 @@ public PagePartitioner( } } - int partitionCount = partitionFunction.getPartitionCount(); + int partitionCount = partitionFunction.partitionCount(); int pageSize = toIntExact(min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, maxMemory.toBytes() / partitionCount)); pageSize = max(1, pageSize); @@ -146,7 +146,7 @@ public void partitionPage(Page page) return; } - if (page.getPositionCount() < partitionFunction.getPartitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { + if (page.getPositionCount() < partitionFunction.partitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { // Partition will have on average less than COLUMNAR_STRATEGY_COEFFICIENT rows. // Doing it column-wise would degrade performance, so we fall back to row-wise approach. // Performance degradation is the worst in case of skewed hash distribution when only small subset @@ -209,7 +209,7 @@ public void partitionPageByColumn(Page page) { IntArrayList[] partitionedPositions = partitionPositions(page); - for (int i = 0; i < partitionFunction.getPartitionCount(); i++) { + for (int i = 0; i < partitionFunction.partitionCount(); i++) { IntArrayList partitionPositions = partitionedPositions[i]; if (!partitionPositions.isEmpty()) { positionsAppenders[i].appendToOutputPartition(page, partitionPositions); @@ -259,9 +259,9 @@ private IntArrayList[] initPositions(Page page) // want memory to explode in case there are input pages with many positions, where each page // is assigned to a single partition entirely. // For example this can happen for partition columns if they are represented by RLE blocks. - IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.getPartitionCount()]; + IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.partitionCount()]; for (int i = 0; i < partitionPositions.length; i++) { - partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.getPartitionCount())); + partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.partitionCount())); } return partitionPositions; } @@ -275,7 +275,7 @@ private static int initialPartitionSize(int averagePositionsPerPartition) return (int) (averagePositionsPerPartition * 1.1) + 32; } - private boolean onlyRleBlocks(Page page) + private static boolean onlyRleBlocks(Page page) { for (int i = 0; i < page.getChannelCount(); i++) { if (!(page.getBlock(i) instanceof RunLengthEncodedBlock)) { @@ -308,7 +308,7 @@ private void partitionBySingleRleValue(Page page, int position, Page partitionFu } } - private Page extractRlePage(Page page) + private static Page extractRlePage(Page page) { Block[] valueBlocks = new Block[page.getChannelCount()]; for (int channel = 0; channel < valueBlocks.length; ++channel) { @@ -317,7 +317,7 @@ private Page extractRlePage(Page page) return new Page(valueBlocks); } - private int[] integersInRange(int start, int endExclusive) + private static int[] integersInRange(int start, int endExclusive) { int[] array = new int[endExclusive - start]; int current = start; @@ -327,7 +327,7 @@ private int[] integersInRange(int start, int endExclusive) return array; } - private boolean isDictionaryProcessingFaster(Block block) + private static boolean isDictionaryProcessingFaster(Block block) { if (!(block instanceof DictionaryBlock dictionaryBlock)) { return false; @@ -386,7 +386,7 @@ private void partitionNullablePositions(Page page, int position, IntArrayList[] } } - private void partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) + private static void partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) { int positionCount = page.getPositionCount(); int[] partitionPerPosition = new int[positionCount]; diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java index 2479f76a7941..1d47760e4e66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java @@ -35,7 +35,7 @@ public class PagePartitionerPool * In normal conditions, in the steady state, * the number of free {@link PagePartitioner}s is going to be close to 0. * There is a possible case though, where initially big number of concurrent drivers, say 128, - * drops to a small number e.g. 32 in a steady state. This could cause a lot of memory + * drops to a small number e.g., 32 in a steady state. This could cause a lot of memory * to be retained by the unused buffers. * To defend against that, {@link #maxFree} limits the number of free buffers, * thus limiting unused memory. diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java index 4b0a38b61a53..e861c4e0f685 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java @@ -14,27 +14,28 @@ package io.trino.operator.output; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; public interface PositionsAppender { - void append(IntArrayList positions, Block source); + void append(IntArrayList positions, ValueBlock source); /** * Appends the specified value positionCount times. - * The result is the same as with using {@link PositionsAppender#append(IntArrayList, Block)} with - * positions list [0...positionCount -1] but with possible performance optimizations. + * The result is the same as with using {@link PositionsAppender#append(IntArrayList, ValueBlock)} with + * a position list [0...positionCount -1] but with possible performance optimizations. */ - void appendRle(Block value, int rlePositionCount); + void appendRle(ValueBlock value, int rlePositionCount); /** * Appends single position. The implementation must be conceptually equal to * {@code append(IntArrayList.wrap(new int[] {position}), source)} but may be optimized. - * Caller should avoid using this method if {@link #append(IntArrayList, Block)} can be used + * Caller should avoid using this method if {@link #append(IntArrayList, ValueBlock)} can be used * as appending positions one by one can be significantly slower and may not support features * like pushing RLE through the appender. */ - void append(int position, Block source); + void append(int position, ValueBlock source); /** * Creates the block from the appender data. diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java index a597983d6a19..34eab30e020e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java @@ -13,13 +13,20 @@ */ package io.trino.operator.output; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Int128ArrayBlock; -import io.trino.spi.type.FixedWidthType; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RowBlock; +import io.trino.spi.block.ShortArrayBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.spi.type.VariableWidthType; import io.trino.type.BlockTypeOperators; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; + +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -32,45 +39,41 @@ public PositionsAppenderFactory(BlockTypeOperators blockTypeOperators) this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); } - public PositionsAppender create(Type type, int expectedPositions, long maxPageSizeInBytes) + public UnnestingPositionsAppender create(Type type, int expectedPositions, long maxPageSizeInBytes) { - if (!type.isComparable()) { - return new UnnestingPositionsAppender(createPrimitiveAppender(type, expectedPositions, maxPageSizeInBytes)); + Optional distinctFromOperator = Optional.empty(); + if (type.isComparable()) { + distinctFromOperator = Optional.of(blockTypeOperators.getDistinctFromOperator(type)); } - - return new UnnestingPositionsAppender( - new RleAwarePositionsAppender( - blockTypeOperators.getDistinctFromOperator(type), - createPrimitiveAppender(type, expectedPositions, maxPageSizeInBytes))); + return new UnnestingPositionsAppender(createPrimitiveAppender(type, expectedPositions, maxPageSizeInBytes), distinctFromOperator); } private PositionsAppender createPrimitiveAppender(Type type, int expectedPositions, long maxPageSizeInBytes) { - if (type instanceof FixedWidthType) { - switch (((FixedWidthType) type).getFixedSize()) { - case Byte.BYTES: - return new BytePositionsAppender(expectedPositions); - case Short.BYTES: - return new ShortPositionsAppender(expectedPositions); - case Integer.BYTES: - return new IntPositionsAppender(expectedPositions); - case Long.BYTES: - return new LongPositionsAppender(expectedPositions); - case Fixed12Block.FIXED12_BYTES: - return new Fixed12PositionsAppender(expectedPositions); - case Int128ArrayBlock.INT128_BYTES: - return new Int128PositionsAppender(expectedPositions); - default: - // size not supported directly, fallback to the generic appender - } + if (type.getValueBlockType() == ByteArrayBlock.class) { + return new BytePositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == ShortArrayBlock.class) { + return new ShortPositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == IntArrayBlock.class) { + return new IntPositionsAppender(expectedPositions); } - else if (type instanceof VariableWidthType) { + if (type.getValueBlockType() == LongArrayBlock.class) { + return new LongPositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == Fixed12Block.class) { + return new Fixed12PositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == Int128ArrayBlock.class) { + return new Int128PositionsAppender(expectedPositions); + } + if (type.getValueBlockType() == VariableWidthBlock.class) { return new SlicePositionsAppender(expectedPositions, maxPageSizeInBytes); } - else if (type instanceof RowType) { + if (type.getValueBlockType() == RowBlock.class) { return RowPositionsAppender.createRowAppender(this, (RowType) type, expectedPositions, maxPageSizeInBytes); } - return new TypedPositionsAppender(type, expectedPositions); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java index e19aaeb97401..7b113d87d429 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderPageBuilder.java @@ -26,7 +26,7 @@ public class PositionsAppenderPageBuilder { private static final int DEFAULT_INITIAL_EXPECTED_ENTRIES = 8; - private final PositionsAppender[] channelAppenders; + private final UnnestingPositionsAppender[] channelAppenders; private final int maxPageSizeInBytes; private int declaredPositions; @@ -45,7 +45,7 @@ private PositionsAppenderPageBuilder( requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); this.maxPageSizeInBytes = maxPageSizeInBytes; - channelAppenders = new PositionsAppender[types.size()]; + channelAppenders = new UnnestingPositionsAppender[types.size()]; for (int i = 0; i < channelAppenders.length; i++) { channelAppenders[i] = positionsAppenderFactory.create(types.get(i), initialExpectedEntries, maxPageSizeInBytes); } @@ -76,7 +76,7 @@ public long getRetainedSizeInBytes() // We use a foreach loop instead of streams // as it has much better performance. long retainedSizeInBytes = 0; - for (PositionsAppender positionsAppender : channelAppenders) { + for (UnnestingPositionsAppender positionsAppender : channelAppenders) { retainedSizeInBytes += positionsAppender.getRetainedSizeInBytes(); } return retainedSizeInBytes; @@ -85,13 +85,13 @@ public long getRetainedSizeInBytes() public long getSizeInBytes() { long sizeInBytes = 0; - for (PositionsAppender positionsAppender : channelAppenders) { + for (UnnestingPositionsAppender positionsAppender : channelAppenders) { sizeInBytes += positionsAppender.getSizeInBytes(); } return sizeInBytes; } - public void declarePositions(int positions) + private void declarePositions(int positions) { declaredPositions += positions; } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java index 001e60e460e4..0d1d6b642096 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderUtil.java @@ -31,7 +31,7 @@ private PositionsAppenderUtil() // Copied from io.trino.spi.block.BlockUtil#calculateNewArraySize static int calculateNewArraySize(int currentSize) { - // grow array by 50% + // grow the array by 50% long newSize = (long) currentSize + (currentSize >> 1); // verify new size is within reasonable bounds diff --git a/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java deleted file mode 100644 index 82480d7edce6..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/output/RleAwarePositionsAppender.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.output; - -import io.trino.spi.block.Block; -import io.trino.spi.block.RunLengthEncodedBlock; -import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; -import it.unimi.dsi.fastutil.ints.IntArrayList; -import jakarta.annotation.Nullable; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.slice.SizeOf.instanceSize; -import static java.util.Objects.requireNonNull; - -/** - * {@link PositionsAppender} that will produce {@link RunLengthEncodedBlock} output if possible, - * that is all inputs are {@link RunLengthEncodedBlock} blocks with the same value. - */ -public class RleAwarePositionsAppender - implements PositionsAppender -{ - private static final int INSTANCE_SIZE = instanceSize(RleAwarePositionsAppender.class); - private static final int NO_RLE = -1; - - private final BlockPositionIsDistinctFrom isDistinctFromOperator; - private final PositionsAppender delegate; - - @Nullable - private Block rleValue; - - // NO_RLE means flat state, 0 means initial empty state, positive means RLE state and the current RLE position count. - private int rlePositionCount; - - public RleAwarePositionsAppender(BlockPositionIsDistinctFrom isDistinctFromOperator, PositionsAppender delegate) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.isDistinctFromOperator = requireNonNull(isDistinctFromOperator, "isDistinctFromOperator is null"); - } - - @Override - public void append(IntArrayList positions, Block source) - { - // RleAwarePositionsAppender should be used with UnnestingPositionsAppender that makes sure - // append is called only with flat block - checkArgument(!(source instanceof RunLengthEncodedBlock), "Append should be called with non-RLE block but got %s", source); - switchToFlat(); - delegate.append(positions, source); - } - - @Override - public void appendRle(Block value, int positionCount) - { - if (positionCount == 0) { - return; - } - checkArgument(value.getPositionCount() == 1, "Expected value to contain a single position but has %d positions".formatted(value.getPositionCount())); - - if (rlePositionCount == 0) { - // initial empty state, switch to RLE state - rleValue = value; - rlePositionCount = positionCount; - } - else if (rleValue != null) { - // we are in the RLE state - if (!isDistinctFromOperator.isDistinctFrom(rleValue, 0, value, 0)) { - // the values match. we can just add positions. - this.rlePositionCount += positionCount; - return; - } - // RLE values do not match. switch to flat state - switchToFlat(); - delegate.appendRle(value, positionCount); - } - else { - // flat state - delegate.appendRle(value, positionCount); - } - } - - @Override - public void append(int position, Block value) - { - switchToFlat(); - delegate.append(position, value); - } - - @Override - public Block build() - { - Block result; - if (rleValue != null) { - result = RunLengthEncodedBlock.create(rleValue, rlePositionCount); - } - else { - result = delegate.build(); - } - - reset(); - return result; - } - - private void reset() - { - rleValue = null; - rlePositionCount = 0; - } - - @Override - public long getRetainedSizeInBytes() - { - long retainedRleSize = rleValue != null ? rleValue.getRetainedSizeInBytes() : 0; - return INSTANCE_SIZE + retainedRleSize + delegate.getRetainedSizeInBytes(); - } - - @Override - public long getSizeInBytes() - { - long rleSize = rleValue != null ? rleValue.getSizeInBytes() : 0; - return rleSize + delegate.getSizeInBytes(); - } - - private void switchToFlat() - { - if (rleValue != null) { - // we are in the RLE state, flatten all RLE blocks - delegate.appendRle(rleValue, rlePositionCount); - rleValue = null; - } - rlePositionCount = NO_RLE; - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java index 84634f3ef66e..da334b7b5506 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/RowPositionsAppender.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.RowType; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -23,6 +24,7 @@ import java.util.List; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -34,7 +36,7 @@ public class RowPositionsAppender implements PositionsAppender { private static final int INSTANCE_SIZE = instanceSize(RowPositionsAppender.class); - private final PositionsAppender[] fieldAppenders; + private final UnnestingPositionsAppender[] fieldAppenders; private int initialEntryCount; private boolean initialized; @@ -51,14 +53,14 @@ public static RowPositionsAppender createRowAppender( int expectedPositions, long maxPageSizeInBytes) { - PositionsAppender[] fields = new PositionsAppender[type.getFields().size()]; + UnnestingPositionsAppender[] fields = new UnnestingPositionsAppender[type.getFields().size()]; for (int i = 0; i < fields.length; i++) { fields[i] = positionsAppenderFactory.create(type.getFields().get(i).getType(), expectedPositions, maxPageSizeInBytes); } return new RowPositionsAppender(fields, expectedPositions); } - private RowPositionsAppender(PositionsAppender[] fieldAppenders, int expectedPositions) + private RowPositionsAppender(UnnestingPositionsAppender[] fieldAppenders, int expectedPositions) { this.fieldAppenders = requireNonNull(fieldAppenders, "fields is null"); this.initialEntryCount = expectedPositions; @@ -66,39 +68,30 @@ private RowPositionsAppender(PositionsAppender[] fieldAppenders, int expectedPos } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + if (positions.isEmpty()) { return; } ensureCapacity(positions.size()); - if (block instanceof RowBlock sourceRowBlock) { - IntArrayList nonNullPositions; - if (sourceRowBlock.mayHaveNull()) { - nonNullPositions = processNullablePositions(positions, sourceRowBlock); - hasNullRow |= nonNullPositions.size() < positions.size(); - hasNonNullRow |= nonNullPositions.size() > 0; - } - else { - // the source Block does not have nulls - nonNullPositions = processNonNullablePositions(positions, sourceRowBlock); - hasNonNullRow = true; - } - - List fieldBlocks = sourceRowBlock.getChildren(); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].append(nonNullPositions, fieldBlocks.get(i)); - } - } - else if (allPositionsNull(positions, block)) { - // all input positions are null. We can handle that even if block type is not RowBLock. - // append positions.size() nulls - Arrays.fill(rowIsNull, positionCount, positionCount + positions.size(), true); - hasNullRow = true; + RowBlock sourceRowBlock = (RowBlock) block; + IntArrayList nonNullPositions; + if (sourceRowBlock.mayHaveNull()) { + nonNullPositions = processNullablePositions(positions, sourceRowBlock); + hasNullRow |= nonNullPositions.size() < positions.size(); + hasNonNullRow |= !nonNullPositions.isEmpty(); } else { - throw new IllegalArgumentException("unsupported block type: " + block); + // the source Block does not have nulls + nonNullPositions = processNonNullablePositions(positions, sourceRowBlock); + hasNonNullRow = true; + } + + List fieldBlocks = sourceRowBlock.getChildren(); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].append(nonNullPositions, fieldBlocks.get(i)); } positionCount += positions.size(); @@ -106,62 +99,49 @@ else if (allPositionsNull(positions, block)) { } @Override - public void appendRle(Block value, int rlePositionCount) + public void appendRle(ValueBlock value, int rlePositionCount) { + checkArgument(value instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + ensureCapacity(rlePositionCount); - if (value instanceof RowBlock sourceRowBlock) { - if (sourceRowBlock.isNull(0)) { - // append rlePositionCount nulls - Arrays.fill(rowIsNull, positionCount, positionCount + rlePositionCount, true); - hasNullRow = true; - } - else { - // append not null row value - List fieldBlocks = sourceRowBlock.getChildren(); - int fieldPosition = sourceRowBlock.getFieldBlockOffset(0); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].appendRle(fieldBlocks.get(i).getSingleValueBlock(fieldPosition), rlePositionCount); - } - hasNonNullRow = true; - } - } - else if (value.isNull(0)) { + RowBlock sourceRowBlock = (RowBlock) value; + if (sourceRowBlock.isNull(0)) { // append rlePositionCount nulls Arrays.fill(rowIsNull, positionCount, positionCount + rlePositionCount, true); hasNullRow = true; } else { - throw new IllegalArgumentException("unsupported block type: " + value); + // append not null row value + List fieldBlocks = sourceRowBlock.getChildren(); + int fieldPosition = sourceRowBlock.getFieldBlockOffset(0); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].appendRle(fieldBlocks.get(i).getSingleValueBlock(fieldPosition), rlePositionCount); + } + hasNonNullRow = true; } positionCount += rlePositionCount; resetSize(); } @Override - public void append(int position, Block value) + public void append(int position, ValueBlock value) { + checkArgument(value instanceof RowBlock, "Block must be instance of %s", RowBlock.class); + ensureCapacity(1); - if (value instanceof RowBlock sourceRowBlock) { - if (sourceRowBlock.isNull(position)) { - rowIsNull[positionCount] = true; - hasNullRow = true; - } - else { - // append not null row value - List fieldBlocks = sourceRowBlock.getChildren(); - int fieldPosition = sourceRowBlock.getFieldBlockOffset(position); - for (int i = 0; i < fieldAppenders.length; i++) { - fieldAppenders[i].append(fieldPosition, fieldBlocks.get(i)); - } - hasNonNullRow = true; - } - } - else if (value.isNull(position)) { + RowBlock sourceRowBlock = (RowBlock) value; + if (sourceRowBlock.isNull(position)) { rowIsNull[positionCount] = true; hasNullRow = true; } else { - throw new IllegalArgumentException("unsupported block type: " + value); + // append not null row value + List fieldBlocks = sourceRowBlock.getChildren(); + int fieldPosition = sourceRowBlock.getFieldBlockOffset(position); + for (int i = 0; i < fieldAppenders.length; i++) { + fieldAppenders[i].append(fieldPosition, fieldBlocks.get(i)); + } + hasNonNullRow = true; } positionCount++; resetSize(); @@ -195,7 +175,7 @@ public long getRetainedSizeInBytes() } long size = INSTANCE_SIZE + sizeOf(rowIsNull); - for (PositionsAppender field : fieldAppenders) { + for (UnnestingPositionsAppender field : fieldAppenders) { size += field.getRetainedSizeInBytes(); } @@ -211,7 +191,7 @@ public long getSizeInBytes() } long size = (Integer.BYTES + Byte.BYTES) * (long) positionCount; - for (PositionsAppender field : fieldAppenders) { + for (UnnestingPositionsAppender field : fieldAppenders) { size += field.getSizeInBytes(); } @@ -230,16 +210,6 @@ private void reset() resetSize(); } - private boolean allPositionsNull(IntArrayList positions, Block block) - { - for (int i = 0; i < positions.size(); i++) { - if (!block.isNull(positions.getInt(i))) { - return false; - } - } - return true; - } - private IntArrayList processNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) { int[] nonNullPositions = new int[positions.size()]; @@ -256,7 +226,7 @@ private IntArrayList processNullablePositions(IntArrayList positions, RowBlock s return IntArrayList.wrap(nonNullPositions, nonNullPositionsCount); } - private IntArrayList processNonNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) + private static IntArrayList processNonNullablePositions(IntArrayList positions, RowBlock sourceRowBlock) { int[] nonNullPositions = new int[positions.size()]; for (int i = 0; i < positions.size(); i++) { diff --git a/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java index 21afc3a700bc..acc9f9f23159 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/ShortPositionsAppender.java @@ -16,11 +16,13 @@ import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.ShortArrayBlock; +import io.trino.spi.block.ValueBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; @@ -55,9 +57,10 @@ public ShortPositionsAppender(int expectedEntries) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + if (positions.isEmpty()) { return; } @@ -94,8 +97,10 @@ public void append(IntArrayList positions, Block block) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + if (rlePositionCount == 0) { return; } @@ -116,8 +121,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int sourcePosition, Block source) + public void append(int sourcePosition, ValueBlock source) { + checkArgument(source instanceof ShortArrayBlock, "Block must be instance of %s", ShortArrayBlock.class); + ensureCapacity(positionCount + 1); if (source.isNull(sourcePosition)) { valueIsNull[positionCount] = true; @@ -185,7 +192,7 @@ private void ensureCapacity(int capacity) newSize = initialEntryCount; initialized = true; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); valueIsNull = Arrays.copyOf(valueIsNull, newSize); values = Arrays.copyOf(values, newSize); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java index 058e5e49a19a..638fc54f3b16 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java @@ -32,11 +32,11 @@ public SkewedPartitionFunction(PartitionFunction partitionFunction, SkewedPartit this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.skewedPartitionRebalancer = requireNonNull(skewedPartitionRebalancer, "skewedPartitionRebalancer is null"); - this.partitionRowCount = new long[partitionFunction.getPartitionCount()]; + this.partitionRowCount = new long[partitionFunction.partitionCount()]; } @Override - public int getPartitionCount() + public int partitionCount() { return skewedPartitionRebalancer.getTaskCount(); } @@ -50,7 +50,7 @@ public int getPartition(Page page, int position) public void flushPartitionRowCountToRebalancer() { - for (int partition = 0; partition < partitionFunction.getPartitionCount(); partition++) { + for (int partition = 0; partition < partitionFunction.partitionCount(); partition++) { skewedPartitionRebalancer.addPartitionRowCount(partition, partitionRowCount[partition]); partitionRowCount[partition] = 0; } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java index 7849f6c61501..1d5d5d64ffd3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SlicePositionsAppender.java @@ -18,12 +18,14 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlock; import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.Arrays; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.instanceSize; @@ -51,7 +53,7 @@ public class SlicePositionsAppender private boolean hasNullValue; private boolean hasNonNullValue; - // it is assumed that the offsets array is one position longer than the valueIsNull array + // it is assumed that the offset array is one position longer than the valueIsNull array private boolean[] valueIsNull = new boolean[0]; private int[] offsets = new int[1]; @@ -74,54 +76,53 @@ public SlicePositionsAppender(int expectedEntries, int expectedBytes) } @Override - // TODO: Make PositionsAppender work performant with different block types (https://github.com/trinodb/trino/issues/13267) - public void append(IntArrayList positions, Block block) + public void append(IntArrayList positions, ValueBlock block) { + checkArgument(block instanceof VariableWidthBlock, "Block must be instance of %s", VariableWidthBlock.class); + if (positions.isEmpty()) { return; } ensurePositionCapacity(positionCount + positions.size()); - if (block instanceof VariableWidthBlock variableWidthBlock) { - int newByteCount = 0; - int[] lengths = new int[positions.size()]; - int[] sourceOffsets = new int[positions.size()]; - int[] positionArray = positions.elements(); - - if (block.mayHaveNull()) { - for (int i = 0; i < positions.size(); i++) { - int position = positionArray[i]; - int length = variableWidthBlock.getSliceLength(position); - lengths[i] = length; - sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); - newByteCount += length; - boolean isNull = block.isNull(position); - valueIsNull[positionCount + i] = isNull; - offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - hasNullValue |= isNull; - hasNonNullValue |= !isNull; - } - } - else { - for (int i = 0; i < positions.size(); i++) { - int position = positionArray[i]; - int length = variableWidthBlock.getSliceLength(position); - lengths[i] = length; - sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); - newByteCount += length; - offsets[positionCount + i + 1] = offsets[positionCount + i] + length; - } - hasNonNullValue = true; + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; + int newByteCount = 0; + int[] lengths = new int[positions.size()]; + int[] sourceOffsets = new int[positions.size()]; + int[] positionArray = positions.elements(); + + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); + lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); + newByteCount += length; + boolean isNull = block.isNull(position); + valueIsNull[positionCount + i] = isNull; + offsets[positionCount + i + 1] = offsets[positionCount + i] + length; + hasNullValue |= isNull; + hasNonNullValue |= !isNull; } - copyBytes(variableWidthBlock.getRawSlice(), lengths, sourceOffsets, positions.size(), newByteCount); } else { - appendGenericBlock(positions, block); + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + int length = variableWidthBlock.getSliceLength(position); + lengths[i] = length; + sourceOffsets[i] = variableWidthBlock.getRawSliceOffset(position); + newByteCount += length; + offsets[positionCount + i + 1] = offsets[positionCount + i] + length; + } + hasNonNullValue = true; } + copyBytes(variableWidthBlock.getRawSlice(), lengths, sourceOffsets, positions.size(), newByteCount); } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { + checkArgument(block instanceof VariableWidthBlock, "Block must be instance of %s", VariableWidthBlock.class); + if (rlePositionCount == 0) { return; } @@ -141,8 +142,10 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int position, Block source) + public void append(int position, ValueBlock source) { + checkArgument(source instanceof VariableWidthBlock, "Block must be instance of %s but is %s".formatted(VariableWidthBlock.class, source.getClass())); + ensurePositionCapacity(positionCount + 1); if (source.isNull(position)) { valueIsNull[positionCount] = true; @@ -259,30 +262,6 @@ static void duplicateBytes(Slice slice, byte[] bytes, int startOffset, int count System.arraycopy(bytes, startOffset, bytes, startOffset + duplicatedBytes, totalDuplicatedBytes - duplicatedBytes); } - private void appendGenericBlock(IntArrayList positions, Block block) - { - int newByteCount = 0; - for (int i = 0; i < positions.size(); i++) { - int position = positions.getInt(i); - if (block.isNull(position)) { - offsets[positionCount + 1] = offsets[positionCount]; - valueIsNull[positionCount] = true; - hasNullValue = true; - } - else { - int length = block.getSliceLength(position); - ensureExtraBytesCapacity(length); - Slice slice = block.getSlice(position, 0, length); - slice.getBytes(0, bytes, offsets[positionCount], length); - offsets[positionCount + 1] = offsets[positionCount] + length; - hasNonNullValue = true; - newByteCount += length; - } - positionCount++; - } - updateSize(positions.size(), newByteCount); - } - private void reset() { initialEntryCount = calculateBlockResetSize(positionCount); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java b/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java index 49b9e87595ec..b687ed09ac74 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/TaskOutputOperator.java @@ -134,7 +134,7 @@ public boolean isFinished() @Override public ListenableFuture isBlocked() { - // Avoid re-synchronizing on the output buffer when operator is already blocked + // Avoid re-synchronizing on the output buffer when the operator is already blocked if (isBlocked.isDone()) { isBlocked = outputBuffer.isFull(); if (isBlocked.isDone()) { diff --git a/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java index 1f66dd05d0dc..9d8ff32d4478 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/TypedPositionsAppender.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -30,20 +31,13 @@ class TypedPositionsAppender private BlockBuilder blockBuilder; TypedPositionsAppender(Type type, int expectedPositions) - { - this( - type, - type.createBlockBuilder(null, expectedPositions)); - } - - TypedPositionsAppender(Type type, BlockBuilder blockBuilder) { this.type = requireNonNull(type, "type is null"); - this.blockBuilder = requireNonNull(blockBuilder, "blockBuilder is null"); + this.blockBuilder = type.createBlockBuilder(null, expectedPositions); } @Override - public void append(IntArrayList positions, Block source) + public void append(IntArrayList positions, ValueBlock source) { int[] positionArray = positions.elements(); for (int i = 0; i < positions.size(); i++) { @@ -52,7 +46,7 @@ public void append(IntArrayList positions, Block source) } @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock block, int rlePositionCount) { for (int i = 0; i < rlePositionCount; i++) { type.appendTo(block, 0, blockBuilder); @@ -60,7 +54,7 @@ public void appendRle(Block block, int rlePositionCount) } @Override - public void append(int position, Block source) + public void append(int position, ValueBlock source) { type.appendTo(source, position, blockBuilder); } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java index cedc3ece1cfd..c360c9b7cb8f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/UnnestingPositionsAppender.java @@ -16,11 +16,18 @@ import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntArrays; +import jakarta.annotation.Nullable; + +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; import static io.trino.operator.output.PositionsAppenderUtil.calculateBlockResetSize; import static io.trino.operator.output.PositionsAppenderUtil.calculateNewArraySize; import static java.lang.Math.max; @@ -30,52 +37,110 @@ * Dispatches the {@link #append} and {@link #appendRle} methods to the {@link #delegate} depending on the input {@link Block} class. */ public class UnnestingPositionsAppender - implements PositionsAppender { private static final int INSTANCE_SIZE = instanceSize(UnnestingPositionsAppender.class); + // The initial state will transition to either the DICTIONARY or RLE state, and from there to the DIRECT state if necessary. + private enum State + { + UNINITIALIZED, DICTIONARY, RLE, DIRECT + } + private final PositionsAppender delegate; - private DictionaryBlockBuilder dictionaryBlockBuilder; + @Nullable + private final BlockPositionIsDistinctFrom isDistinctFromOperator; - public UnnestingPositionsAppender(PositionsAppender delegate) + private State state = State.UNINITIALIZED; + + private ValueBlock dictionary; + private DictionaryIdsBuilder dictionaryIdsBuilder; + + @Nullable + private ValueBlock rleValue; + private int rlePositionCount; + + public UnnestingPositionsAppender(PositionsAppender delegate, Optional isDistinctFromOperator) { this.delegate = requireNonNull(delegate, "delegate is null"); - this.dictionaryBlockBuilder = new DictionaryBlockBuilder(); + this.dictionaryIdsBuilder = new DictionaryIdsBuilder(1024); + this.isDistinctFromOperator = isDistinctFromOperator.orElse(null); } - @Override public void append(IntArrayList positions, Block source) { if (positions.isEmpty()) { return; } - if (source instanceof RunLengthEncodedBlock) { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.appendRle(((RunLengthEncodedBlock) source).getValue(), positions.size()); + + if (source instanceof RunLengthEncodedBlock rleBlock) { + appendRle(rleBlock.getValue(), positions.size()); + } + else if (source instanceof DictionaryBlock dictionaryBlock) { + ValueBlock dictionary = dictionaryBlock.getDictionary(); + if (state == State.UNINITIALIZED) { + state = State.DICTIONARY; + this.dictionary = dictionary; + dictionaryIdsBuilder.appendPositions(positions, dictionaryBlock); + } + else if (state == State.DICTIONARY && this.dictionary == dictionary) { + dictionaryIdsBuilder.appendPositions(positions, dictionaryBlock); + } + else { + transitionToDirect(); + + int[] positionArray = new int[positions.size()]; + for (int i = 0; i < positions.size(); i++) { + positionArray[i] = dictionaryBlock.getId(positions.getInt(i)); + } + delegate.append(IntArrayList.wrap(positionArray), dictionary); + } } - else if (source instanceof DictionaryBlock) { - appendDictionary(positions, (DictionaryBlock) source); + else if (source instanceof ValueBlock valueBlock) { + transitionToDirect(); + delegate.append(positions, valueBlock); } else { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.append(positions, source); + throw new IllegalArgumentException("Unsupported block type: " + source.getClass().getSimpleName()); } } - @Override - public void appendRle(Block block, int rlePositionCount) + public void appendRle(ValueBlock value, int positionCount) { - if (rlePositionCount == 0) { + if (positionCount == 0) { return; } - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.appendRle(block, rlePositionCount); + + if (state == State.DICTIONARY) { + transitionToDirect(); + } + if (isDistinctFromOperator == null) { + transitionToDirect(); + } + + if (state == State.UNINITIALIZED) { + state = State.RLE; + rleValue = value; + rlePositionCount = positionCount; + return; + } + if (state == State.RLE) { + if (!isDistinctFromOperator.isDistinctFrom(rleValue, 0, value, 0)) { + // the values match. we can just add positions. + rlePositionCount += positionCount; + return; + } + transitionToDirect(); + } + + verify(state == State.DIRECT); + delegate.appendRle(value, positionCount); } - @Override public void append(int position, Block source) { - dictionaryBlockBuilder.flushDictionary(delegate); + if (state != State.DIRECT) { + transitionToDirect(); + } if (source instanceof RunLengthEncodedBlock runLengthEncodedBlock) { delegate.append(0, runLengthEncodedBlock.getValue()); @@ -83,134 +148,108 @@ public void append(int position, Block source) else if (source instanceof DictionaryBlock dictionaryBlock) { delegate.append(dictionaryBlock.getId(position), dictionaryBlock.getDictionary()); } + else if (source instanceof ValueBlock valueBlock) { + delegate.append(position, valueBlock); + } else { - delegate.append(position, source); + throw new IllegalArgumentException("Unsupported block type: " + source.getClass().getSimpleName()); } } - @Override - public Block build() + private void transitionToDirect() { - Block result; - if (dictionaryBlockBuilder.isEmpty()) { - result = delegate.build(); + if (state == State.DICTIONARY) { + int[] dictionaryIds = dictionaryIdsBuilder.getDictionaryIds(); + delegate.append(IntArrayList.wrap(dictionaryIds, dictionaryIdsBuilder.size()), dictionary); + dictionary = null; + dictionaryIdsBuilder = dictionaryIdsBuilder.newBuilderLike(); } - else { - result = dictionaryBlockBuilder.build(); + else if (state == State.RLE) { + delegate.appendRle(rleValue, rlePositionCount); + rleValue = null; + rlePositionCount = 0; } - dictionaryBlockBuilder = dictionaryBlockBuilder.newBuilderLike(); - return result; + state = State.DIRECT; } - @Override - public long getRetainedSizeInBytes() + public Block build() { - return INSTANCE_SIZE + delegate.getRetainedSizeInBytes() + dictionaryBlockBuilder.getRetainedSizeInBytes(); - } + Block result = switch (state) { + case DICTIONARY -> DictionaryBlock.create(dictionaryIdsBuilder.size(), dictionary, dictionaryIdsBuilder.getDictionaryIds()); + case RLE -> RunLengthEncodedBlock.create(rleValue, rlePositionCount); + case UNINITIALIZED, DIRECT -> delegate.build(); + }; - @Override - public long getSizeInBytes() - { - return delegate.getSizeInBytes(); + state = State.UNINITIALIZED; + dictionary = null; + dictionaryIdsBuilder = dictionaryIdsBuilder.newBuilderLike(); + rleValue = null; + rlePositionCount = 0; + + return result; } - private void appendDictionary(IntArrayList positions, DictionaryBlock source) + public long getRetainedSizeInBytes() { - Block dictionary = source.getDictionary(); - if (dictionary instanceof RunLengthEncodedBlock rleDictionary) { - appendRle(rleDictionary.getValue(), positions.size()); - return; - } - - IntArrayList dictionaryPositions = getDictionaryPositions(positions, source); - if (dictionaryBlockBuilder.canAppend(dictionary)) { - dictionaryBlockBuilder.append(dictionaryPositions, dictionary); - } - else { - dictionaryBlockBuilder.flushDictionary(delegate); - delegate.append(dictionaryPositions, dictionary); - } + return INSTANCE_SIZE + + delegate.getRetainedSizeInBytes() + + dictionaryIdsBuilder.getRetainedSizeInBytes() + + (rleValue != null ? rleValue.getRetainedSizeInBytes() : 0); } - private IntArrayList getDictionaryPositions(IntArrayList positions, DictionaryBlock block) + public long getSizeInBytes() { - int[] positionArray = new int[positions.size()]; - for (int i = 0; i < positions.size(); i++) { - positionArray[i] = block.getId(positions.getInt(i)); - } - return IntArrayList.wrap(positionArray); + return delegate.getSizeInBytes() + + // dictionary size is not included due to the expense of the calculation + (rleValue != null ? rleValue.getSizeInBytes() : 0); } - private static class DictionaryBlockBuilder + private static class DictionaryIdsBuilder { - private static final int INSTANCE_SIZE = instanceSize(DictionaryBlockBuilder.class); + private static final int INSTANCE_SIZE = instanceSize(DictionaryIdsBuilder.class); + private final int initialEntryCount; - private Block dictionary; private int[] dictionaryIds; - private int positionCount; - private boolean closed; - - public DictionaryBlockBuilder() - { - this(1024); - } + private int size; - public DictionaryBlockBuilder(int initialEntryCount) + public DictionaryIdsBuilder(int initialEntryCount) { this.initialEntryCount = initialEntryCount; this.dictionaryIds = new int[0]; } - public boolean isEmpty() + public int[] getDictionaryIds() { - return positionCount == 0; + return dictionaryIds; } - public Block build() + public int size() { - return DictionaryBlock.create(positionCount, dictionary, dictionaryIds); + return size; } public long getRetainedSizeInBytes() { - return INSTANCE_SIZE - + (long) dictionaryIds.length * Integer.BYTES - + (dictionary != null ? dictionary.getRetainedSizeInBytes() : 0); + return INSTANCE_SIZE + sizeOf(dictionaryIds); } - public boolean canAppend(Block dictionary) + public void appendPositions(IntArrayList positions, DictionaryBlock block) { - return !closed && (dictionary == this.dictionary || this.dictionary == null); - } + checkArgument(!positions.isEmpty(), "positions is empty"); + ensureCapacity(size + positions.size()); - public void append(IntArrayList mappedPositions, Block dictionary) - { - checkArgument(canAppend(dictionary)); - this.dictionary = dictionary; - ensureCapacity(positionCount + mappedPositions.size()); - System.arraycopy(mappedPositions.elements(), 0, dictionaryIds, positionCount, mappedPositions.size()); - positionCount += mappedPositions.size(); - } - - public void flushDictionary(PositionsAppender delegate) - { - if (closed) { - return; - } - if (positionCount > 0) { - requireNonNull(dictionary, () -> "dictionary is null but we have pending dictionaryIds " + positionCount); - delegate.append(IntArrayList.wrap(dictionaryIds, positionCount), dictionary); + for (int i = 0; i < positions.size(); i++) { + dictionaryIds[size + i] = block.getId(positions.getInt(i)); } - - closed = true; - dictionaryIds = new int[0]; - positionCount = 0; - dictionary = null; + size += positions.size(); } - public DictionaryBlockBuilder newBuilderLike() + public DictionaryIdsBuilder newBuilderLike() { - return new DictionaryBlockBuilder(max(calculateBlockResetSize(positionCount), initialEntryCount)); + if (size == 0) { + return this; + } + return new DictionaryIdsBuilder(max(calculateBlockResetSize(size), initialEntryCount)); } private void ensureCapacity(int capacity) @@ -226,9 +265,9 @@ private void ensureCapacity(int capacity) else { newSize = initialEntryCount; } - newSize = Math.max(newSize, capacity); + newSize = max(newSize, capacity); - dictionaryIds = IntArrays.ensureCapacity(dictionaryIds, newSize, positionCount); + dictionaryIds = IntArrays.ensureCapacity(dictionaryIds, newSize, size); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java index 4b386470415e..e0332a537189 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayHistogramFunction.java @@ -18,6 +18,7 @@ import io.trino.spi.block.MapBlock; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.Convention; import io.trino.spi.function.Description; import io.trino.spi.function.OperatorDependency; @@ -30,8 +31,8 @@ import java.lang.invoke.MethodHandle; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -53,7 +54,7 @@ public static SqlMap arrayHistogram( @OperatorDependency( operator = OperatorType.READ_VALUE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FLAT_RETURN)) MethodHandle writeFlat, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", @@ -61,19 +62,20 @@ public static SqlMap arrayHistogram( @OperatorDependency( operator = OperatorType.IS_DISTINCT_FROM, argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {FLAT, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, + convention = @Convention(arguments = {FLAT, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle distinctFlatBlock, @OperatorDependency( operator = OperatorType.HASH_CODE, argumentTypes = "T", - convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock, + convention = @Convention(arguments = VALUE_BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle hashBlock, @TypeParameter("map(T, bigint)") MapType mapType, @SqlType("array(T)") Block arrayBlock) { TypedHistogram histogram = new TypedHistogram(elementType, readFlat, writeFlat, hashFlat, distinctFlatBlock, hashBlock, false); - int positionCount = arrayBlock.getPositionCount(); - for (int position = 0; position < positionCount; position++) { + ValueBlock valueBlock = arrayBlock.getUnderlyingValueBlock(); + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + int position = arrayBlock.getUnderlyingValuePosition(i); if (!arrayBlock.isNull(position)) { - histogram.add(0, arrayBlock, position, 1L); + histogram.add(0, valueBlock, position, 1L); } } MapBlockBuilder blockBuilder = mapType.createBlockBuilder(null, histogram.size()); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java index 8db7d8638c12..daf0bd6f29fc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java @@ -208,6 +208,10 @@ private static int computeScore(InvocationConvention callingConvention) case BLOCK_POSITION: score += 1000; break; + case VALUE_BLOCK_POSITION_NOT_NULL: + case VALUE_BLOCK_POSITION: + score += 2000; + break; case IN_OUT: score += 10_000; break; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java index bbf3c42aa355..2ced4b20154c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java @@ -26,6 +26,7 @@ import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction.ScalarImplementationChoice; import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -59,7 +60,7 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.ImmutableSortedSet.toImmutableSortedSet; @@ -82,6 +83,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; @@ -255,6 +258,11 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch methodHandleParameterTypes.add(Block.class); methodHandleParameterTypes.add(int.class); break; + case VALUE_BLOCK_POSITION: + case VALUE_BLOCK_POSITION_NOT_NULL: + methodHandleParameterTypes.add(ValueBlock.class); + methodHandleParameterTypes.add(int.class); + break; case IN_OUT: methodHandleParameterTypes.add(InOut.class); break; @@ -599,15 +607,21 @@ private void parseArguments(Method method, Signature.Builder signatureBuilder, L else { // value type InvocationArgumentConvention argumentConvention; + boolean nullable = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) { - checkState(method.getParameterCount() > (parameterIndex + 1)); - checkState(parameterType == Block.class); + verify(method.getParameterCount() > (parameterIndex + 1)); - argumentConvention = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance) ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + if (parameterType == Block.class) { + argumentConvention = nullable ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + } + else { + verify(ValueBlock.class.isAssignableFrom(parameterType)); + argumentConvention = nullable ? VALUE_BLOCK_POSITION : VALUE_BLOCK_POSITION_NOT_NULL; + } Annotation[] parameterAnnotations = method.getParameterAnnotations()[parameterIndex + 1]; - checkState(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); + verify(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance)); } - else if (Stream.of(annotations).anyMatch(SqlNullable.class::isInstance)) { + else if (nullable) { checkCondition(!parameterType.isPrimitive(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName()); argumentConvention = BOXED_NULLABLE; @@ -641,7 +655,7 @@ else if (parameterType.equals(InOut.class)) { } } - if (argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL) { + if (argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == VALUE_BLOCK_POSITION || argumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { argumentNativeContainerTypes.add(Optional.of(type.nativeContainerType())); } else { diff --git a/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java b/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java index 2d19dc6bc13c..eb56a954aeaf 100644 --- a/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java +++ b/core/trino-main/src/main/java/io/trino/spiller/GenericPartitioningSpiller.java @@ -79,7 +79,7 @@ public GenericPartitioningSpiller( requireNonNull(memoryContext, "memoryContext is null"); closer.register(memoryContext::close); this.memoryContext = memoryContext; - int partitionCount = partitionFunction.getPartitionCount(); + int partitionCount = partitionFunction.partitionCount(); ImmutableList.Builder pageBuilders = ImmutableList.builder(); spillers = new ArrayList<>(partitionCount); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java index 43cde586607a..da90def6a9a7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/CompilerConfig.java @@ -22,6 +22,7 @@ public class CompilerConfig { private int expressionCacheSize = 10_000; + private boolean specializeAggregationLoops = true; @Min(0) public int getExpressionCacheSize() @@ -36,4 +37,16 @@ public CompilerConfig setExpressionCacheSize(int expressionCacheSize) this.expressionCacheSize = expressionCacheSize; return this; } + + public boolean isSpecializeAggregationLoops() + { + return specializeAggregationLoops; + } + + @Config("compiler.specialized-aggregation-loops") + public CompilerConfig setSpecializeAggregationLoops(boolean specializeAggregationLoops) + { + this.specializeAggregationLoops = specializeAggregationLoops; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index c5c8f6f7f266..f0286d52289e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -443,6 +443,7 @@ public class LocalExecutionPlanner private final ExchangeManagerRegistry exchangeManagerRegistry; private final PositionsAppenderFactory positionsAppenderFactory; private final NodeVersion version; + private final boolean specializeAggregationLoops; private final NonEvictableCache accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -477,7 +478,8 @@ public LocalExecutionPlanner( TypeOperators typeOperators, TableExecuteContextManager tableExecuteContextManager, ExchangeManagerRegistry exchangeManagerRegistry, - NodeVersion version) + NodeVersion version, + CompilerConfig compilerConfig) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.metadata = plannerContext.getMetadata(); @@ -524,6 +526,7 @@ public LocalExecutionPlanner( this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.positionsAppenderFactory = new PositionsAppenderFactory(blockTypeOperators); this.version = requireNonNull(version, "version is null"); + this.specializeAggregationLoops = compilerConfig.isSpecializeAggregationLoops(); } public LocalExecutionPlan plan( @@ -589,7 +592,7 @@ public LocalExecutionPlan plan( // Keep the task bucket count to 50% of total local writers int taskBucketCount = (int) ceil(0.5 * partitionedWriterCount); skewedPartitionRebalancer = Optional.of(new SkewedPartitionRebalancer( - partitionFunction.getPartitionCount(), + partitionFunction.partitionCount(), taskCount, taskBucketCount, getWriterScalingMinDataProcessed(taskContext.getSession()).toBytes(), @@ -3822,7 +3825,8 @@ private AggregatorFactory buildAggregatorFactory( () -> generateAccumulatorFactory( resolvedFunction.getSignature(), aggregationImplementation, - resolvedFunction.getFunctionNullability())); + resolvedFunction.getFunctionNullability(), + specializeAggregationLoops)); if (aggregation.isDistinct()) { accumulatorFactory = new DistinctAccumulatorFactory( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java index 2b9608defade..95f522c66606 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java @@ -199,13 +199,13 @@ public MergePartitionFunction(PartitionFunction insertFunction, PartitionFunctio this.updateFunction = requireNonNull(updateFunction, "updateFunction is null"); this.insertColumns = requireNonNull(insertColumns, "insertColumns is null"); this.updateColumns = requireNonNull(updateColumns, "updateColumns is null"); - checkArgument(insertFunction.getPartitionCount() == updateFunction.getPartitionCount(), "partition counts must match"); + checkArgument(insertFunction.partitionCount() == updateFunction.partitionCount(), "partition counts must match"); } @Override - public int getPartitionCount() + public int partitionCount() { - return insertFunction.getPartitionCount(); + return insertFunction.partitionCount(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index c2960e5cc630..bee7e1aba871 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -549,7 +549,7 @@ public PlanNode plan(Delete node) assignmentsBuilder.putIdentity(symbol); } else { - assignmentsBuilder.put(symbol, new NullLiteral()); + assignmentsBuilder.put(symbol, new Cast(new NullLiteral(), toSqlType(symbolAllocator.getTypes().get(symbol)))); } } List columnSymbols = columnSymbolsBuilder.build(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java index 8c02efa9d16d..55cc935ed77a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java @@ -468,19 +468,19 @@ public Expression rewriteCurrentTime(CurrentTime node, Void context, ExpressionT .build(); case TIME -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$current_time") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); case LOCALTIME -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$localtime") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); case TIMESTAMP -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$current_timestamp") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); case LOCALTIMESTAMP -> BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata()) .setName("$localtimestamp") - .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new NullLiteral())) + .setArguments(ImmutableList.of(analysis.getType(node)), ImmutableList.of(new Cast(new NullLiteral(), toSqlType(analysis.getType(node))))) .build(); }; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java index 9360b029254f..9a16f1256b85 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyTableExecute.java @@ -23,12 +23,15 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.ValuesNode; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Row; import java.util.Optional; import static com.google.common.base.Verify.verify; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.Patterns.Values.rowCount; import static io.trino.sql.planner.plan.Patterns.tableFinish; import static io.trino.sql.planner.plan.Patterns.values; @@ -86,7 +89,7 @@ public Result apply(TableFinishNode finishNode, Captures captures, Context conte new ValuesNode( finishNode.getId(), finishNode.getOutputSymbols(), - ImmutableList.of(new Row(ImmutableList.of(new NullLiteral()))))); + ImmutableList.of(new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT))))))); } private Optional getSingleSourceSkipExchange(PlanNode node, Lookup lookup) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java index 896a5a31f785..ea7a5458e893 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformUncorrelatedSubqueryToJoin.java @@ -19,12 +19,14 @@ import com.google.common.collect.Sets; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.spi.type.Type; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.Expression; import io.trino.sql.tree.IfExpression; import io.trino.sql.tree.NullLiteral; @@ -33,6 +35,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.matching.Pattern.empty; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.FULL; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT; @@ -91,7 +94,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co for (Symbol inputSymbol : Sets.intersection( ImmutableSet.copyOf(correlatedJoinNode.getInput().getOutputSymbols()), ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))) { - assignments.put(inputSymbol, new IfExpression(correlatedJoinNode.getFilter(), inputSymbol.toSymbolReference(), new NullLiteral())); + Type inputType = context.getSymbolAllocator().getTypes().get(inputSymbol); + assignments.put(inputSymbol, new IfExpression(correlatedJoinNode.getFilter(), inputSymbol.toSymbolReference(), new Cast(new NullLiteral(), toSqlType(inputType)))); } ProjectNode projectNode = new ProjectNode( context.getIdAllocator().getNextId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index b62a0a8d7006..991ea2f25b8f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -56,6 +56,7 @@ import static io.trino.SystemSessionProperties.isOptimizeDistinctAggregationEnabled; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -374,7 +375,7 @@ else if (aggregationOutputSymbolsMap.containsKey(symbol)) { // add null assignment for mask // unused mask will be removed by PruneUnreferencedOutputs - outputSymbols.put(aggregateInfo.getMask(), new NullLiteral()); + outputSymbols.put(aggregateInfo.getMask(), new Cast(new NullLiteral(), toSqlType(BOOLEAN))); aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateSymbols.buildOrThrow()); diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java index f79b1dcc072c..9f917d2e6d44 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeInputRewrite.java @@ -24,6 +24,7 @@ import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.DescribeInput; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Limit; @@ -32,6 +33,7 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.Query; import io.trino.sql.tree.Row; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -42,6 +44,8 @@ import static io.trino.SystemSessionProperties.isOmitDateTimeTypePrecision; import static io.trino.execution.ParameterExtractor.extractParameters; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.ascending; import static io.trino.sql.QueryUtil.identifier; @@ -51,6 +55,7 @@ import static io.trino.sql.QueryUtil.simpleQuery; import static io.trino.sql.QueryUtil.values; import static io.trino.sql.analyzer.QueryType.DESCRIBE; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.type.TypeUtils.getDisplayLabel; import static io.trino.type.UnknownType.UNKNOWN; import static java.util.Objects.requireNonNull; @@ -82,6 +87,12 @@ public Statement rewrite( private static final class Visitor extends AstVisitor { + private static final Query EMPTY_INPUT = createDesctibeInputQuery( + new Row[]{row( + new Cast(new NullLiteral(), toSqlType(BIGINT)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)))}, + Optional.of(new Limit(new LongLiteral("0")))); + private final Session session; private final SqlParser parser; private final AnalyzerFactory analyzerFactory; @@ -130,10 +141,14 @@ protected Node visitDescribeInput(DescribeInput node, Void context) Row[] rows = builder.build().toArray(Row[]::new); Optional limit = Optional.empty(); if (rows.length == 0) { - rows = new Row[] {row(new NullLiteral(), new NullLiteral())}; - limit = Optional.of(new Limit(new LongLiteral("0"))); + return EMPTY_INPUT; } + return createDesctibeInputQuery(rows, limit); + } + + private static Query createDesctibeInputQuery(Row[] rows, Optional limit) + { return simpleQuery( selectList(identifier("Position"), identifier("Type")), aliased( diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java index b0ae19fe9821..b8a78d502e29 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/DescribeOutputRewrite.java @@ -27,6 +27,7 @@ import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BooleanLiteral; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.DescribeOutput; import io.trino.sql.tree.Expression; import io.trino.sql.tree.Limit; @@ -35,6 +36,7 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; +import io.trino.sql.tree.Query; import io.trino.sql.tree.Row; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -44,6 +46,9 @@ import java.util.Optional; import static io.trino.SystemSessionProperties.isOmitDateTimeTypePrecision; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.identifier; import static io.trino.sql.QueryUtil.row; @@ -51,6 +56,7 @@ import static io.trino.sql.QueryUtil.simpleQuery; import static io.trino.sql.QueryUtil.values; import static io.trino.sql.analyzer.QueryType.DESCRIBE; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.type.TypeUtils.getDisplayLabel; import static java.util.Objects.requireNonNull; @@ -81,6 +87,17 @@ public Statement rewrite( private static final class Visitor extends AstVisitor { + private static final Query EMPTY_OUTPUT = createDesctibeOutputQuery( + new Row[]{row( + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(VARCHAR)), + new Cast(new NullLiteral(), toSqlType(BIGINT)), + new Cast(new NullLiteral(), toSqlType(BOOLEAN)))}, + Optional.of(new Limit(new LongLiteral("0")))); + private final Session session; private final SqlParser parser; private final AnalyzerFactory analyzerFactory; @@ -119,10 +136,13 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context) Optional limit = Optional.empty(); Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new); if (rows.length == 0) { - NullLiteral nullLiteral = new NullLiteral(); - rows = new Row[] {row(nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral, nullLiteral)}; - limit = Optional.of(new Limit(new LongLiteral("0"))); + return EMPTY_OUTPUT; } + return createDesctibeOutputQuery(rows, limit); + } + + private static Query createDesctibeOutputQuery(Row[] rows, Optional limit) + { return simpleQuery( selectList( identifier("Column Name"), diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 2f3199c29d27..54c5ddec92f2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -54,6 +54,7 @@ import io.trino.spi.security.PrincipalType; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.session.PropertyMetadata; +import io.trino.spi.type.Type; import io.trino.sql.analyzer.AnalyzerFactory; import io.trino.sql.parser.ParsingException; import io.trino.sql.parser.SqlParser; @@ -61,6 +62,7 @@ import io.trino.sql.tree.Array; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BooleanLiteral; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.ColumnDefinition; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateSchema; @@ -75,13 +77,16 @@ import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Parameter; import io.trino.sql.tree.PrincipalSpecification; import io.trino.sql.tree.Property; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.Query; +import io.trino.sql.tree.QuerySpecification; import io.trino.sql.tree.Relation; import io.trino.sql.tree.Row; +import io.trino.sql.tree.SelectItem; import io.trino.sql.tree.ShowCatalogs; import io.trino.sql.tree.ShowColumns; import io.trino.sql.tree.ShowCreate; @@ -92,6 +97,7 @@ import io.trino.sql.tree.ShowSchemas; import io.trino.sql.tree.ShowSession; import io.trino.sql.tree.ShowTables; +import io.trino.sql.tree.SingleColumn; import io.trino.sql.tree.SortItem; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; @@ -128,18 +134,19 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.SCHEMA_NOT_FOUND; import static io.trino.spi.StandardErrorCode.TABLE_NOT_FOUND; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ExpressionUtils.combineConjuncts; import static io.trino.sql.QueryUtil.aliased; import static io.trino.sql.QueryUtil.aliasedName; import static io.trino.sql.QueryUtil.aliasedNullToEmpty; import static io.trino.sql.QueryUtil.ascending; -import static io.trino.sql.QueryUtil.emptyQuery; import static io.trino.sql.QueryUtil.equal; import static io.trino.sql.QueryUtil.functionCall; import static io.trino.sql.QueryUtil.identifier; import static io.trino.sql.QueryUtil.logicalAnd; import static io.trino.sql.QueryUtil.ordering; +import static io.trino.sql.QueryUtil.query; import static io.trino.sql.QueryUtil.row; import static io.trino.sql.QueryUtil.selectAll; import static io.trino.sql.QueryUtil.selectList; @@ -377,13 +384,13 @@ protected Node visitShowRoles(ShowRoles node, Void context) List rows = enabledRoles.stream() .map(role -> row(new StringLiteral(role))) .collect(toList()); - return singleColumnValues(rows, "Role"); + return singleColumnValues(rows, "Role", VARCHAR); } accessControl.checkCanShowRoles(session.toSecurityContext(), catalog); List rows = metadata.listRoles(session, catalog).stream() .map(role -> row(new StringLiteral(role))) .collect(toList()); - return singleColumnValues(rows, "Role"); + return singleColumnValues(rows, "Role", VARCHAR); } @Override @@ -402,14 +409,14 @@ protected Node visitShowRoleGrants(ShowRoleGrants node, Void context) .map(roleGrant -> row(new StringLiteral(roleGrant.getRoleName()))) .collect(toList()); - return singleColumnValues(rows, "Role Grants"); + return singleColumnValues(rows, "Role Grants", VARCHAR); } - private static Query singleColumnValues(List rows, String columnName) + private static Query singleColumnValues(List rows, String columnName, Type type) { List columns = ImmutableList.of(columnName); if (rows.isEmpty()) { - return emptyQuery(columns); + return emptyQuery(columns, ImmutableList.of(type)); } return simpleQuery( selectList(new AllColumns()), @@ -803,7 +810,7 @@ protected Node visitShowFunctions(ShowFunctions node, Void context) .buildOrThrow(); if (rows.isEmpty()) { - return emptyQuery(ImmutableList.copyOf(columns.values())); + return emptyQuery(ImmutableList.copyOf(columns.values()), ImmutableList.of(VARCHAR, VARCHAR, VARCHAR, VARCHAR, BOOLEAN, VARCHAR)); } return simpleQuery( @@ -949,5 +956,24 @@ protected Node visitNode(Node node, Void context) { return node; } + + public static Query emptyQuery(List columns, List types) + { + ImmutableList.Builder items = ImmutableList.builder(); + for (int i = 0; i < columns.size(); i++) { + items.add(new SingleColumn(new Cast(new NullLiteral(), toSqlType(types.get(i))), identifier(columns.get(i)))); + } + Optional where = Optional.of(FALSE_LITERAL); + return query(new QuerySpecification( + selectAll(items.build()), + Optional.empty(), + where, + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } } } diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 0e6ce21f962d..c3947b7a9d6d 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -168,6 +168,7 @@ import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.parser.SqlParser; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.LogicalPlanner; @@ -999,7 +1000,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out typeOperators, tableExecuteContextManager, exchangeManagerRegistry, - nodeManager.getCurrentNode().getNodeVersion()); + nodeManager.getCurrentNode().getNodeVersion(), + new CompilerConfig()); // plan query LocalExecutionPlan localExecutionPlan = executionPlanner.plan( diff --git a/core/trino-main/src/main/java/io/trino/type/CodePointsType.java b/core/trino-main/src/main/java/io/trino/type/CodePointsType.java index 699ab2ce504a..c5c33f9afc92 100644 --- a/core/trino-main/src/main/java/io/trino/type/CodePointsType.java +++ b/core/trino-main/src/main/java/io/trino/type/CodePointsType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -46,7 +47,9 @@ public Object getObject(Block block, int position) return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice slice = valueBlock.getSlice(valuePosition); int[] codePoints = new int[slice.length() / Integer.BYTES]; slice.getInts(0, codePoints); return codePoints; diff --git a/core/trino-main/src/main/java/io/trino/type/ColorType.java b/core/trino-main/src/main/java/io/trino/type/ColorType.java index e6329872dcef..cac180fbbc6d 100644 --- a/core/trino-main/src/main/java/io/trino/type/ColorType.java +++ b/core/trino-main/src/main/java/io/trino/type/ColorType.java @@ -46,7 +46,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - int color = block.getInt(position, 0); + int color = getInt(block, position); if (color < 0) { return ColorFunctions.SystemColor.valueOf(-(color + 1)).getName(); } diff --git a/core/trino-main/src/main/java/io/trino/type/FunctionType.java b/core/trino-main/src/main/java/io/trino/type/FunctionType.java index 5757fe91c7e0..2a5e6a790f4c 100644 --- a/core/trino-main/src/main/java/io/trino/type/FunctionType.java +++ b/core/trino-main/src/main/java/io/trino/type/FunctionType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -92,7 +93,13 @@ public String getDisplayName() @Override public final Class getJavaType() { - throw new UnsupportedOperationException(getTypeSignature() + " type does not have Java type"); + throw new UnsupportedOperationException(getTypeSignature() + " type does not have a Java type"); + } + + @Override + public Class getValueBlockType() + { + throw new UnsupportedOperationException(getTypeSignature() + " type does not have a ValueBlock type"); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java b/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java index 0cb9ed359e9d..ebcfe8dea965 100644 --- a/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java +++ b/core/trino-main/src/main/java/io/trino/type/IntervalYearMonthType.java @@ -35,7 +35,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - return new SqlIntervalYearMonth(block.getInt(position, 0)); + return new SqlIntervalYearMonth(getInt(block, position)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/IpAddressType.java b/core/trino-main/src/main/java/io/trino/type/IpAddressType.java index b248ec808175..c0bff38df5bc 100644 --- a/core/trino-main/src/main/java/io/trino/type/IpAddressType.java +++ b/core/trino-main/src/main/java/io/trino/type/IpAddressType.java @@ -20,6 +20,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -63,7 +64,7 @@ public class IpAddressType private IpAddressType() { - super(new TypeSignature(StandardTypes.IPADDRESS), Slice.class); + super(new TypeSignature(StandardTypes.IPADDRESS), Slice.class, Int128ArrayBlock.class); } @Override @@ -219,7 +220,7 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return equal( leftBlock.getLong(leftPosition, 0), @@ -240,7 +241,7 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) { return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); } @@ -261,7 +262,7 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return compareBigEndian( leftBlock.getLong(leftPosition, 0), diff --git a/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java b/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java index 6bc6bf3587b9..606639ce4110 100644 --- a/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java +++ b/core/trino-main/src/main/java/io/trino/type/JoniRegexpType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -47,7 +48,9 @@ public Object getObject(Block block, int position) return null; } - return joniRegexp(block.getSlice(position, 0, block.getSliceLength(position))); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return joniRegexp(valueBlock.getSlice(valuePosition)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/Json2016Type.java b/core/trino-main/src/main/java/io/trino/type/Json2016Type.java index 5480dc224707..f028551ddc32 100644 --- a/core/trino-main/src/main/java/io/trino/type/Json2016Type.java +++ b/core/trino-main/src/main/java/io/trino/type/Json2016Type.java @@ -21,6 +21,7 @@ import io.trino.operator.scalar.json.JsonOutputConversionError; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -54,7 +55,9 @@ public Object getObject(Block block, int position) return null; } - String json = block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8(); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); if (json.equals(JSON_ERROR.toString())) { return JSON_ERROR; } diff --git a/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java b/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java index 02fd29daa304..fc7f6ed88dbc 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonPath2016Type.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -57,8 +58,10 @@ public Object getObject(Block block, int position) return null; } - Slice bytes = block.getSlice(position, 0, block.getSliceLength(position)); - return jsonPathCodec.fromJson(bytes.toStringUtf8()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); + return jsonPathCodec.fromJson(json); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/JsonPathType.java b/core/trino-main/src/main/java/io/trino/type/JsonPathType.java index 767ea5075626..addc7a034f4f 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonPathType.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonPathType.java @@ -18,6 +18,7 @@ import io.trino.operator.scalar.JsonPath; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -47,7 +48,10 @@ public Object getObject(Block block, int position) return null; } - return new JsonPath(block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String pattern = valueBlock.getSlice(valuePosition).toStringUtf8(); + return new JsonPath(pattern); } @Override diff --git a/core/trino-main/src/main/java/io/trino/type/JsonType.java b/core/trino-main/src/main/java/io/trino/type/JsonType.java index a2d95bdbe10e..f077f2287fb3 100644 --- a/core/trino-main/src/main/java/io/trino/type/JsonType.java +++ b/core/trino-main/src/main/java/io/trino/type/JsonType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -62,13 +63,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8(); + return getSlice(block, position).toStringUtf8(); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java index 8f9666e77779..180f9a33fa2a 100644 --- a/core/trino-main/src/main/java/io/trino/type/LikePatternType.java +++ b/core/trino-main/src/main/java/io/trino/type/LikePatternType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -23,8 +24,8 @@ import java.util.Optional; -import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.Slices.utf8Slice; +import static java.nio.charset.StandardCharsets.UTF_8; public class LikePatternType extends AbstractVariableWidthType @@ -50,19 +51,19 @@ public Object getObject(Block block, int position) return null; } + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice slice = valueBlock.getSlice(valuePosition); + // layout is: ? - int offset = 0; - int length = block.getInt(position, offset); - offset += SIZE_OF_INT; - String pattern = block.getSlice(position, offset, length).toStringUtf8(); - offset += length; + int length = slice.getInt(0); + String pattern = slice.toString(4, length, UTF_8); - boolean hasEscape = block.getByte(position, offset) != 0; - offset++; + boolean hasEscape = slice.getByte(4 + length) != 0; Optional escape = Optional.empty(); if (hasEscape) { - escape = Optional.of((char) block.getInt(position, offset)); + escape = Optional.of((char) slice.getInt(4 + length + 1)); } return LikePattern.compile(pattern, escape); diff --git a/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java b/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java index 98b8690725bb..1d6807fd4d71 100644 --- a/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java +++ b/core/trino-main/src/main/java/io/trino/type/Re2JRegexpType.java @@ -18,6 +18,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -54,7 +55,9 @@ public Object getObject(Block block, int position) return null; } - Slice pattern = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice pattern = valueBlock.getSlice(valuePosition); try { return new Re2JRegexp(dfaStatesLimit, dfaRetries, pattern); } diff --git a/core/trino-main/src/main/java/io/trino/type/TDigestType.java b/core/trino-main/src/main/java/io/trino/type/TDigestType.java index a49fafb84622..b37130082a3b 100644 --- a/core/trino-main/src/main/java/io/trino/type/TDigestType.java +++ b/core/trino-main/src/main/java/io/trino/type/TDigestType.java @@ -17,6 +17,7 @@ import io.airlift.stats.TDigest; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -37,7 +38,9 @@ private TDigestType() @Override public Object getObject(Block block, int position) { - return TDigest.deserialize(block.getSlice(position, 0, block.getSliceLength(position))); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return TDigest.deserialize(valueBlock.getSlice(valuePosition)); } @Override @@ -54,6 +57,8 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new SqlVarbinary(valueBlock.getSlice(valuePosition).getBytes()); } } diff --git a/core/trino-main/src/main/java/io/trino/type/UnknownType.java b/core/trino-main/src/main/java/io/trino/type/UnknownType.java index 6fd9e1c31703..92406f8684f2 100644 --- a/core/trino-main/src/main/java/io/trino/type/UnknownType.java +++ b/core/trino-main/src/main/java/io/trino/type/UnknownType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -51,7 +52,7 @@ private UnknownType() // We never access the native container for UNKNOWN because its null check is always true. // The actual native container type does not matter here. // We choose boolean to represent UNKNOWN because it's the smallest primitive type. - super(new TypeSignature(NAME), boolean.class); + super(new TypeSignature(NAME), boolean.class, ByteArrayBlock.class); } @Override @@ -122,8 +123,8 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public boolean getBoolean(Block block, int position) { - // Ideally, this function should never be invoked for unknown type. - // However, some logic rely on having a default value before the null check. + // Ideally, this function should never be invoked for the unknown type. + // However, some logic relies on having a default value before the null check. checkArgument(block.isNull(position)); return false; } @@ -132,8 +133,8 @@ public boolean getBoolean(Block block, int position) @Override public void writeBoolean(BlockBuilder blockBuilder, boolean value) { - // Ideally, this function should never be invoked for unknown type. - // However, some logic (e.g. AbstractMinMaxBy) rely on writing a default value before the null check. + // Ideally, this function should never be invoked for the unknown type. + // However, some logic (e.g. AbstractMinMaxBy) relies on writing a default value before the null check. checkArgument(!value); blockBuilder.appendNull(); } diff --git a/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java b/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java index d2094c6c2f77..7d3195992756 100644 --- a/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java +++ b/core/trino-main/src/main/java/io/trino/type/setdigest/SetDigestType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -44,13 +45,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java index 8da257d28d24..986a5fc8cca9 100644 --- a/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java +++ b/core/trino-main/src/test/java/io/trino/block/AbstractTestBlock.java @@ -17,7 +17,6 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; @@ -26,12 +25,13 @@ import io.trino.spi.block.DictionaryId; import io.trino.spi.block.MapHashTables; import io.trino.spi.block.TestingBlockEncodingSerde; -import io.trino.spi.block.VariableWidthBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; import java.lang.reflect.Array; import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.List; @@ -84,6 +84,10 @@ protected void assertBlock(Block block, T[] expectedValues) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Invalid position %d in block with %d positions", block.getPositionCount(), block.getPositionCount()); } + + if (block instanceof ValueBlock valueBlock) { + assertBlockClassImplementation(valueBlock.getClass()); + } } private void assertRetainedSize(Block block) @@ -113,7 +117,7 @@ else if (type == BlockBuilderStatus.class) { retainedSize += BlockBuilderStatus.INSTANCE_SIZE; } } - else if (type == Block.class) { + else if (Block.class.isAssignableFrom(type)) { retainedSize += ((Block) field.get(block)).getRetainedSizeInBytes(); } else if (type == Block[].class) { @@ -295,7 +299,13 @@ protected void assertPositionValue(Block block, int position, T expectedValu if (isSliceAccessSupported()) { assertEquals(block.getSliceLength(position), expectedSliceValue.length()); - assertSlicePosition(block, position, expectedSliceValue); + + int length = block.getSliceLength(position); + assertEquals(length, expectedSliceValue.length()); + + for (int offset = 0; offset < length - 3; offset++) { + assertEquals(block.getSlice(position, offset, 3), expectedSliceValue.slice(offset, 3)); + } } assertPositionEquals(block, position, expectedSliceValue); @@ -326,34 +336,6 @@ else if (expectedValue instanceof long[][] expected) { } } - protected void assertSlicePosition(Block block, int position, Slice expectedSliceValue) - { - int length = block.getSliceLength(position); - assertEquals(length, expectedSliceValue.length()); - - Block expectedBlock = toSingeValuedBlock(expectedSliceValue); - for (int offset = 0; offset < length - 3; offset++) { - assertEquals(block.getSlice(position, offset, 3), expectedSliceValue.slice(offset, 3)); - assertTrue(block.bytesEqual(position, offset, expectedSliceValue, offset, 3)); - // if your tests fail here, please change your test to not use this value - assertFalse(block.bytesEqual(position, offset, Slices.utf8Slice("XXX"), 0, 3)); - - assertEquals(block.bytesCompare(position, offset, 3, expectedSliceValue, offset, 3), 0); - assertTrue(block.bytesCompare(position, offset, 3, expectedSliceValue, offset, 2) > 0); - Slice greaterSlice = createGreaterValue(expectedSliceValue, offset, 3); - assertTrue(block.bytesCompare(position, offset, 3, greaterSlice, 0, greaterSlice.length()) < 0); - - assertTrue(block.equals(position, offset, expectedBlock, 0, offset, 3)); - assertEquals(block.compareTo(position, offset, 3, expectedBlock, 0, offset, 3), 0); - - VariableWidthBlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 1); - blockBuilder.writeEntry(block.getSlice(position, offset, 3)); - Block segment = blockBuilder.build(); - - assertTrue(block.equals(position, offset, segment, 0, 0, 3)); - } - } - protected boolean isByteAccessSupported() { return true; @@ -498,4 +480,13 @@ protected static void testIncompactBlock(Block block) assertNotCompact(block); testCopyRegionCompactness(block); } + + private void assertBlockClassImplementation(Class clazz) + { + for (Method method : clazz.getMethods()) { + if (method.getReturnType() == ValueBlock.class && !method.isBridge()) { + throw new AssertionError(format("ValueBlock method %s should override return type to be %s", method, clazz.getSimpleName())); + } + } + } } diff --git a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java index e420c33be0a6..739716f48e50 100644 --- a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java +++ b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java @@ -22,6 +22,7 @@ import io.trino.spi.block.RowBlock; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -138,7 +139,7 @@ public static RunLengthEncodedBlock createRandomRleBlock(Block block, int positi return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(block.getSingleValueBlock(random().nextInt(block.getPositionCount())), positionCount); } - public static Block createRandomBlockForType(Type type, int positionCount, float nullRate) + public static ValueBlock createRandomBlockForType(Type type, int positionCount, float nullRate) { verifyNullRate(nullRate); @@ -191,12 +192,12 @@ public static Block createRandomBlockForType(Type type, int positionCount, float return createRandomBlockForNestedType(type, positionCount, nullRate); } - public static Block createRandomBlockForNestedType(Type type, int positionCount, float nullRate) + public static ValueBlock createRandomBlockForNestedType(Type type, int positionCount, float nullRate) { return createRandomBlockForNestedType(type, positionCount, nullRate, ENTRY_SIZE); } - public static Block createRandomBlockForNestedType(Type type, int positionCount, float nullRate, int maxCardinality) + public static ValueBlock createRandomBlockForNestedType(Type type, int positionCount, float nullRate, int maxCardinality) { // Builds isNull and offsets of size positionCount boolean[] isNull = null; @@ -222,12 +223,12 @@ public static Block createRandomBlockForNestedType(Type type, int positionCount, // Builds the nested block of size offsets[positionCount]. if (type instanceof ArrayType) { - Block valuesBlock = createRandomBlockForType(((ArrayType) type).getElementType(), offsets[positionCount], nullRate); + ValueBlock valuesBlock = createRandomBlockForType(((ArrayType) type).getElementType(), offsets[positionCount], nullRate); return fromElementBlock(positionCount, Optional.ofNullable(isNull), offsets, valuesBlock); } if (type instanceof MapType mapType) { - Block keyBlock = createRandomBlockForType(mapType.getKeyType(), offsets[positionCount], 0.0f); - Block valueBlock = createRandomBlockForType(mapType.getValueType(), offsets[positionCount], nullRate); + ValueBlock keyBlock = createRandomBlockForType(mapType.getKeyType(), offsets[positionCount], 0.0f); + ValueBlock valueBlock = createRandomBlockForType(mapType.getValueType(), offsets[positionCount], nullRate); return mapType.createBlockFromKeyValue(Optional.ofNullable(isNull), offsets, keyBlock, valueBlock); } @@ -245,19 +246,19 @@ public static Block createRandomBlockForNestedType(Type type, int positionCount, throw new IllegalArgumentException(format("type %s is not supported.", type)); } - public static Block createRandomBooleansBlock(int positionCount, float nullRate) + public static ValueBlock createRandomBooleansBlock(int positionCount, float nullRate) { Random random = random(); return createBooleansBlock(generateListWithNulls(positionCount, nullRate, random::nextBoolean)); } - public static Block createRandomIntsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomIntsBlock(int positionCount, float nullRate) { Random random = random(); return createIntsBlock(generateListWithNulls(positionCount, nullRate, random::nextInt)); } - public static Block createRandomLongDecimalsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomLongDecimalsBlock(int positionCount, float nullRate) { Random random = random(); return createLongDecimalsBlock(generateListWithNulls( @@ -266,7 +267,7 @@ public static Block createRandomLongDecimalsBlock(int positionCount, float nullR () -> String.valueOf(random.nextLong()))); } - public static Block createRandomShortTimestampBlock(TimestampType type, int positionCount, float nullRate) + public static ValueBlock createRandomShortTimestampBlock(TimestampType type, int positionCount, float nullRate) { Random random = random(); return createLongsBlock( @@ -276,7 +277,7 @@ public static Block createRandomShortTimestampBlock(TimestampType type, int posi () -> SqlTimestamp.fromMillis(type.getPrecision(), random.nextLong()).getEpochMicros())); } - public static Block createRandomLongTimestampBlock(TimestampType type, int positionCount, float nullRate) + public static ValueBlock createRandomLongTimestampBlock(TimestampType type, int positionCount, float nullRate) { Random random = random(); return createLongTimestampBlock( @@ -290,7 +291,7 @@ public static Block createRandomLongTimestampBlock(TimestampType type, int posit })); } - public static Block createRandomLongsBlock(int positionCount, int numberOfUniqueValues) + public static ValueBlock createRandomLongsBlock(int positionCount, int numberOfUniqueValues) { checkArgument(positionCount >= numberOfUniqueValues, "numberOfUniqueValues must be between 1 and positionCount: %s but was %s", positionCount, numberOfUniqueValues); int[] uniqueValues = chooseRandomUnique(positionCount, numberOfUniqueValues).stream() @@ -303,13 +304,13 @@ public static Block createRandomLongsBlock(int positionCount, int numberOfUnique .collect(toImmutableList())); } - public static Block createRandomLongsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomLongsBlock(int positionCount, float nullRate) { Random random = random(); return createLongsBlock(generateListWithNulls(positionCount, nullRate, random::nextLong)); } - public static Block createRandomSmallintsBlock(int positionCount, float nullRate) + public static ValueBlock createRandomSmallintsBlock(int positionCount, float nullRate) { Random random = random(); return createTypedLongsBlock( @@ -317,43 +318,43 @@ public static Block createRandomSmallintsBlock(int positionCount, float nullRate generateListWithNulls(positionCount, nullRate, () -> (long) (short) random.nextLong())); } - public static Block createRandomStringBlock(int positionCount, float nullRate, int maxStringLength) + public static ValueBlock createRandomStringBlock(int positionCount, float nullRate, int maxStringLength) { return createStringsBlock( generateListWithNulls(positionCount, nullRate, () -> generateRandomStringWithLength(maxStringLength))); } - private static Block createRandomVarbinariesBlock(int positionCount, float nullRate) + private static ValueBlock createRandomVarbinariesBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(VARBINARY, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomUUIDsBlock(int positionCount, float nullRate) + private static ValueBlock createRandomUUIDsBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(UUID, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomIpAddressesBlock(int positionCount, float nullRate) + private static ValueBlock createRandomIpAddressesBlock(int positionCount, float nullRate) { Random random = random(); return createSlicesBlock(IPADDRESS, generateListWithNulls(positionCount, nullRate, () -> Slices.random(16, random))); } - private static Block createRandomTinyintsBlock(int positionCount, float nullRate) + private static ValueBlock createRandomTinyintsBlock(int positionCount, float nullRate) { Random random = random(); return createTypedLongsBlock(TINYINT, generateListWithNulls(positionCount, nullRate, () -> (long) (byte) random.nextLong())); } - public static Block createRandomDoublesBlock(int positionCount, float nullRate) + public static ValueBlock createRandomDoublesBlock(int positionCount, float nullRate) { Random random = random(); return createDoublesBlock(generateListWithNulls(positionCount, nullRate, random::nextDouble)); } - public static Block createRandomCharsBlock(CharType charType, int positionCount, float nullRate) + public static ValueBlock createRandomCharsBlock(CharType charType, int positionCount, float nullRate) { return createCharsBlock(charType, generateListWithNulls(positionCount, nullRate, () -> generateRandomStringWithLength(charType.getLength()))); } @@ -379,14 +380,14 @@ public static Set chooseNullPositions(int positionCount, float nullRate return chooseRandomUnique(positionCount, nullCount); } - public static Block createStringsBlock(String... values) + public static ValueBlock createStringsBlock(String... values) { requireNonNull(values, "values is null"); return createStringsBlock(Arrays.asList(values)); } - public static Block createStringsBlock(Iterable values) + public static ValueBlock createStringsBlock(Iterable values) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 100); @@ -399,26 +400,26 @@ public static Block createStringsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createSlicesBlock(Slice... values) + public static ValueBlock createSlicesBlock(Slice... values) { requireNonNull(values, "values is null"); return createSlicesBlock(Arrays.asList(values)); } - public static Block createSlicesBlock(Iterable values) + public static ValueBlock createSlicesBlock(Iterable values) { return createSlicesBlock(VARBINARY, values); } - public static Block createSlicesBlock(Type type, Iterable values) + public static ValueBlock createSlicesBlock(Type type, Iterable values) { return createBlock(type, type::writeSlice, values); } - public static Block createStringSequenceBlock(int start, int end) + public static ValueBlock createStringSequenceBlock(int start, int end) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 100); @@ -426,7 +427,7 @@ public static Block createStringSequenceBlock(int start, int end) VARCHAR.writeString(builder, String.valueOf(i)); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createStringDictionaryBlock(int start, int length) @@ -445,7 +446,7 @@ public static Block createStringDictionaryBlock(int start, int length) return DictionaryBlock.create(ids.length, builder.build(), ids); } - public static Block createStringArraysBlock(Iterable> values) + public static ValueBlock createStringArraysBlock(Iterable> values) { ArrayType arrayType = new ArrayType(VARCHAR); BlockBuilder builder = arrayType.createBlockBuilder(null, 100); @@ -459,22 +460,22 @@ public static Block createStringArraysBlock(Iterable> } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBooleansBlock(Boolean... values) + public static ValueBlock createBooleansBlock(Boolean... values) { requireNonNull(values, "values is null"); return createBooleansBlock(Arrays.asList(values)); } - public static Block createBooleansBlock(Boolean value, int count) + public static ValueBlock createBooleansBlock(Boolean value, int count) { return createBooleansBlock(Collections.nCopies(count, value)); } - public static Block createBooleansBlock(Iterable values) + public static ValueBlock createBooleansBlock(Iterable values) { BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 100); @@ -487,17 +488,17 @@ public static Block createBooleansBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createShortDecimalsBlock(String... values) + public static ValueBlock createShortDecimalsBlock(String... values) { requireNonNull(values, "values is null"); return createShortDecimalsBlock(Arrays.asList(values)); } - public static Block createShortDecimalsBlock(Iterable values) + public static ValueBlock createShortDecimalsBlock(Iterable values) { DecimalType shortDecimalType = DecimalType.createDecimalType(1); BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 100); @@ -511,17 +512,17 @@ public static Block createShortDecimalsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongDecimalsBlock(String... values) + public static ValueBlock createLongDecimalsBlock(String... values) { requireNonNull(values, "values is null"); return createLongDecimalsBlock(Arrays.asList(values)); } - public static Block createLongDecimalsBlock(Iterable values) + public static ValueBlock createLongDecimalsBlock(Iterable values) { DecimalType longDecimalType = DecimalType.createDecimalType(MAX_SHORT_PRECISION + 1); BlockBuilder builder = longDecimalType.createBlockBuilder(null, 100); @@ -535,16 +536,16 @@ public static Block createLongDecimalsBlock(Iterable values) } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongTimestampBlock(TimestampType type, LongTimestamp... values) + public static ValueBlock createLongTimestampBlock(TimestampType type, LongTimestamp... values) { requireNonNull(values, "values is null"); return createLongTimestampBlock(type, Arrays.asList(values)); } - public static Block createLongTimestampBlock(TimestampType type, Iterable values) + public static ValueBlock createLongTimestampBlock(TimestampType type, Iterable values) { BlockBuilder builder = type.createBlockBuilder(null, 100); @@ -557,51 +558,51 @@ public static Block createLongTimestampBlock(TimestampType type, Iterable values) + public static ValueBlock createCharsBlock(CharType charType, List values) { return createBlock(charType, charType::writeString, values); } - public static Block createTinyintsBlock(Integer... values) + public static ValueBlock createTinyintsBlock(Integer... values) { requireNonNull(values, "values is null"); return createTinyintsBlock(Arrays.asList(values)); } - public static Block createTinyintsBlock(Iterable values) + public static ValueBlock createTinyintsBlock(Iterable values) { return createBlock(TINYINT, (ValueWriter) TINYINT::writeLong, values); } - public static Block createSmallintsBlock(Integer... values) + public static ValueBlock createSmallintsBlock(Integer... values) { requireNonNull(values, "values is null"); return createSmallintsBlock(Arrays.asList(values)); } - public static Block createSmallintsBlock(Iterable values) + public static ValueBlock createSmallintsBlock(Iterable values) { return createBlock(SMALLINT, (ValueWriter) SMALLINT::writeLong, values); } - public static Block createIntsBlock(Integer... values) + public static ValueBlock createIntsBlock(Integer... values) { requireNonNull(values, "values is null"); return createIntsBlock(Arrays.asList(values)); } - public static Block createIntsBlock(Iterable values) + public static ValueBlock createIntsBlock(Iterable values) { return createBlock(INTEGER, (ValueWriter) INTEGER::writeLong, values); } - public static Block createRowBlock(List fieldTypes, Object[]... rows) + public static ValueBlock createRowBlock(List fieldTypes, Object[]... rows) { RowBlockBuilder rowBlockBuilder = new RowBlockBuilder(fieldTypes, null, 1); for (Object[] row : rows) { @@ -647,16 +648,16 @@ else if (fieldValue instanceof Integer) { }); } - return rowBlockBuilder.build(); + return rowBlockBuilder.buildValueBlock(); } - public static Block createEmptyLongsBlock() + public static ValueBlock createEmptyLongsBlock() { - return BIGINT.createFixedSizeBlockBuilder(0).build(); + return BIGINT.createFixedSizeBlockBuilder(0).buildValueBlock(); } // This method makes it easy to create blocks without having to add an L to every value - public static Block createLongsBlock(int... values) + public static ValueBlock createLongsBlock(int... values) { BlockBuilder builder = BIGINT.createBlockBuilder(null, 100); @@ -664,27 +665,27 @@ public static Block createLongsBlock(int... values) BIGINT.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongsBlock(Long... values) + public static ValueBlock createLongsBlock(Long... values) { requireNonNull(values, "values is null"); return createLongsBlock(Arrays.asList(values)); } - public static Block createLongsBlock(Iterable values) + public static ValueBlock createLongsBlock(Iterable values) { return createTypedLongsBlock(BIGINT, values); } - public static Block createTypedLongsBlock(Type type, Iterable values) + public static ValueBlock createTypedLongsBlock(Type type, Iterable values) { return createBlock(type, type::writeLong, values); } - public static Block createLongSequenceBlock(int start, int end) + public static ValueBlock createLongSequenceBlock(int start, int end) { BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(end - start); @@ -692,7 +693,7 @@ public static Block createLongSequenceBlock(int start, int end) BIGINT.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createLongDictionaryBlock(int start, int length) @@ -716,34 +717,34 @@ public static Block createLongDictionaryBlock(int start, int length, int diction return DictionaryBlock.create(ids.length, builder.build(), ids); } - public static Block createLongRepeatBlock(int value, int length) + public static ValueBlock createLongRepeatBlock(int value, int length) { BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { BIGINT.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDoubleRepeatBlock(double value, int length) + public static ValueBlock createDoubleRepeatBlock(double value, int length) { BlockBuilder builder = DOUBLE.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { DOUBLE.writeDouble(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createTimestampsWithTimeZoneMillisBlock(Long... values) + public static ValueBlock createTimestampsWithTimeZoneMillisBlock(Long... values) { BlockBuilder builder = TIMESTAMP_TZ_MILLIS.createFixedSizeBlockBuilder(values.length); for (long value : values) { TIMESTAMP_TZ_MILLIS.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBooleanSequenceBlock(int start, int end) + public static ValueBlock createBooleanSequenceBlock(int start, int end) { BlockBuilder builder = BOOLEAN.createFixedSizeBlockBuilder(end - start); @@ -751,17 +752,17 @@ public static Block createBooleanSequenceBlock(int start, int end) BOOLEAN.writeBoolean(builder, i % 2 == 0); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createBlockOfReals(Float... values) + public static ValueBlock createBlockOfReals(Float... values) { requireNonNull(values, "values is null"); return createBlockOfReals(Arrays.asList(values)); } - public static Block createBlockOfReals(Iterable values) + public static ValueBlock createBlockOfReals(Iterable values) { BlockBuilder builder = REAL.createBlockBuilder(null, 100); for (Float value : values) { @@ -772,10 +773,10 @@ public static Block createBlockOfReals(Iterable values) REAL.writeLong(builder, floatToRawIntBits(value)); } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createSequenceBlockOfReal(int start, int end) + public static ValueBlock createSequenceBlockOfReal(int start, int end) { BlockBuilder builder = REAL.createFixedSizeBlockBuilder(end - start); @@ -783,22 +784,22 @@ public static Block createSequenceBlockOfReal(int start, int end) REAL.writeLong(builder, floatToRawIntBits(i)); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDoublesBlock(Double... values) + public static ValueBlock createDoublesBlock(Double... values) { requireNonNull(values, "values is null"); return createDoublesBlock(Arrays.asList(values)); } - public static Block createDoublesBlock(Iterable values) + public static ValueBlock createDoublesBlock(Iterable values) { return createBlock(DOUBLE, DOUBLE::writeDouble, values); } - public static Block createDoubleSequenceBlock(int start, int end) + public static ValueBlock createDoubleSequenceBlock(int start, int end) { BlockBuilder builder = DOUBLE.createFixedSizeBlockBuilder(end - start); @@ -806,10 +807,10 @@ public static Block createDoubleSequenceBlock(int start, int end) DOUBLE.writeDouble(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createArrayBigintBlock(Iterable> values) + public static ValueBlock createArrayBigintBlock(Iterable> values) { ArrayType arrayType = new ArrayType(BIGINT); BlockBuilder builder = arrayType.createBlockBuilder(null, 100); @@ -823,10 +824,10 @@ public static Block createArrayBigintBlock(Iterable> va } } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createDateSequenceBlock(int start, int end) + public static ValueBlock createDateSequenceBlock(int start, int end) { BlockBuilder builder = DATE.createFixedSizeBlockBuilder(end - start); @@ -834,10 +835,10 @@ public static Block createDateSequenceBlock(int start, int end) DATE.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createTimestampSequenceBlock(int start, int end) + public static ValueBlock createTimestampSequenceBlock(int start, int end) { BlockBuilder builder = TIMESTAMP_MILLIS.createFixedSizeBlockBuilder(end - start); @@ -845,10 +846,10 @@ public static Block createTimestampSequenceBlock(int start, int end) TIMESTAMP_MILLIS.writeLong(builder, multiplyExact(i, MICROSECONDS_PER_MILLISECOND)); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createShortDecimalSequenceBlock(int start, int end, DecimalType type) + public static ValueBlock createShortDecimalSequenceBlock(int start, int end, DecimalType type) { BlockBuilder builder = type.createFixedSizeBlockBuilder(end - start); long base = BigInteger.TEN.pow(type.getScale()).longValue(); @@ -857,10 +858,10 @@ public static Block createShortDecimalSequenceBlock(int start, int end, DecimalT type.writeLong(builder, base * i); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createLongDecimalSequenceBlock(int start, int end, DecimalType type) + public static ValueBlock createLongDecimalSequenceBlock(int start, int end, DecimalType type) { BlockBuilder builder = type.createFixedSizeBlockBuilder(end - start); BigInteger base = BigInteger.TEN.pow(type.getScale()); @@ -869,25 +870,25 @@ public static Block createLongDecimalSequenceBlock(int start, int end, DecimalTy type.writeObject(builder, Int128.valueOf(BigInteger.valueOf(i).multiply(base))); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createColorRepeatBlock(int value, int length) + public static ValueBlock createColorRepeatBlock(int value, int length) { BlockBuilder builder = COLOR.createFixedSizeBlockBuilder(length); for (int i = 0; i < length; i++) { COLOR.writeLong(builder, value); } - return builder.build(); + return builder.buildValueBlock(); } - public static Block createColorSequenceBlock(int start, int end) + public static ValueBlock createColorSequenceBlock(int start, int end) { BlockBuilder builder = COLOR.createBlockBuilder(null, end - start); for (int i = start; i < end; ++i) { COLOR.writeLong(builder, i); } - return builder.build(); + return builder.buildValueBlock(); } public static Block createRepeatedValuesBlock(double value, int positionCount) @@ -904,7 +905,7 @@ public static Block createRepeatedValuesBlock(long value, int positionCount) return RunLengthEncodedBlock.create(blockBuilder.build(), positionCount); } - private static Block createBlock(Type type, ValueWriter valueWriter, Iterable values) + private static ValueBlock createBlock(Type type, ValueWriter valueWriter, Iterable values) { BlockBuilder builder = type.createBlockBuilder(null, 100); @@ -917,7 +918,7 @@ private static Block createBlock(Type type, ValueWriter valueWriter, Iter } } - return builder.build(); + return builder.buildValueBlock(); } private interface ValueWriter diff --git a/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java b/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java index 6c08e69f4f0a..7eb0dac72e12 100644 --- a/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java +++ b/core/trino-main/src/test/java/io/trino/block/TestBlockBuilder.java @@ -138,7 +138,7 @@ private static void assertInvalidPosition(Block block, int[] positions, int offs { assertThatThrownBy(() -> block.getPositions(positions, offset, length).getLong(0, 0)) .isInstanceOfAny(IllegalArgumentException.class, IndexOutOfBoundsException.class) - .hasMessage("Invalid position %d in block with %d positions", positions[0], block.getPositionCount()); + .hasMessage("Invalid position %d and length 1 in block with %d positions", positions[0], block.getPositionCount()); } private static void assertInvalidOffset(Block block, int[] positions, int offset, int length) diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index d5c81aee99ba..6702a7e5bd2b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -43,6 +43,7 @@ import io.trino.sql.gen.JoinFilterFunctionCompiler; import io.trino.sql.gen.OrderingCompiler; import io.trino.sql.gen.PageFunctionCompiler; +import io.trino.sql.planner.CompilerConfig; import io.trino.sql.planner.LocalExecutionPlanner; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.Partitioning; @@ -177,7 +178,8 @@ public static LocalExecutionPlanner createTestingPlanner() PLANNER_CONTEXT.getTypeOperators(), new TableExecuteContextManager(), new ExchangeManagerRegistry(), - new NodeVersion("test")); + new NodeVersion("test"), + new CompilerConfig()); } public static TaskInfo updateTask(SqlTask sqlTask, List splitAssignments, OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 00d91fbe8514..174820181ad1 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -38,6 +38,7 @@ import io.trino.security.AllowAllAccessControl; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.function.AggregationState; @@ -388,7 +389,7 @@ public static final class BlockInputAggregationFunction @InputFunction public static void input( @AggregationState NullableDoubleState state, - @BlockPosition @SqlType(DOUBLE) Block value, + @BlockPosition @SqlType(DOUBLE) ValueBlock value, @BlockIndex int id) { // noop this is only for annotation testing puproses diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java index 0c788312c62e..0e5c486c22df 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java @@ -53,6 +53,7 @@ import static io.trino.block.BlockAssertions.createDoubleRepeatBlock; import static io.trino.block.BlockAssertions.createDoubleSequenceBlock; import static io.trino.block.BlockAssertions.createDoublesBlock; +import static io.trino.block.BlockAssertions.createIntsBlock; import static io.trino.block.BlockAssertions.createLongRepeatBlock; import static io.trino.block.BlockAssertions.createLongSequenceBlock; import static io.trino.block.BlockAssertions.createLongsBlock; @@ -280,9 +281,9 @@ public void testCollectWithNulls() OperatorFactory operatorFactory = createOperatorFactory(channel(0, INTEGER)); verifyPassthrough(createOperator(operatorFactory), ImmutableList.of(INTEGER), - new Page(createLongsBlock(1, 2, 3)), + new Page(createIntsBlock(1, 2, 3)), new Page(blockWithNulls), - new Page(createLongsBlock(4, 5))); + new Page(createIntsBlock(4, 5))); operatorFactory.noMoreOperators(); assertEquals(partitions.build(), ImmutableList.of( diff --git a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java b/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java index a01b11145539..42f41023daae 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java @@ -14,11 +14,12 @@ package io.trino.operator; import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; -import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.LazyBlock; import org.junit.jupiter.api.Test; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import static io.trino.block.BlockAssertions.createIntsBlock; @@ -50,17 +51,20 @@ public void testRecordMaterializedBytes() public void testNestedBlocks() { Block elements = lazyWrapper(createIntsBlock(1, 2, 3)); - Block dictBlock = DictionaryBlock.create(2, elements, new int[] {0, 0}); - Page page = new Page(2, dictBlock); + Block arrayBlock = ArrayBlock.fromElementBlock(2, Optional.empty(), new int[] {0, 1, 3}, elements); + long initialArraySize = arrayBlock.getSizeInBytes(); + Page page = new Page(2, arrayBlock); AtomicLong sizeInBytes = new AtomicLong(); recordMaterializedBytes(page, sizeInBytes::getAndAdd); - assertEquals(sizeInBytes.get(), dictBlock.getSizeInBytes()); + assertEquals(arrayBlock.getSizeInBytes(), initialArraySize); + assertEquals(sizeInBytes.get(), arrayBlock.getSizeInBytes()); // dictionary block caches size in bytes - dictBlock.getLoadedBlock(); - assertEquals(sizeInBytes.get(), dictBlock.getSizeInBytes() + elements.getSizeInBytes()); + arrayBlock.getLoadedBlock(); + assertEquals(sizeInBytes.get(), arrayBlock.getSizeInBytes()); + assertEquals(sizeInBytes.get(), initialArraySize + elements.getSizeInBytes()); } private static LazyBlock lazyWrapper(Block block) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 1a106a41c949..0d2561864be8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -42,6 +42,7 @@ import java.util.Optional; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; +import static io.trino.operator.aggregation.AccumulatorCompiler.generateAccumulatorFactory; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; @@ -54,15 +55,28 @@ public class TestAccumulatorCompiler { @Test public void testAccumulatorCompilerForTypeSpecificObjectParameter() + { + testAccumulatorCompilerForTypeSpecificObjectParameter(true); + testAccumulatorCompilerForTypeSpecificObjectParameter(false); + } + + private void testAccumulatorCompilerForTypeSpecificObjectParameter(boolean specializedLoops) { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); - assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class); + assertGenerateAccumulator(LongTimestampAggregation.class, LongTimestampAggregationState.class, specializedLoops); } @Test public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader() throws Exception + { + testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(true); + testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(false); + } + + private void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLoader(boolean specializedLoops) + throws Exception { TimestampType parameterType = TimestampType.TIMESTAMP_NANOS; assertThat(parameterType.getJavaType()).isEqualTo(LongTimestamp.class); @@ -80,10 +94,10 @@ public void testAccumulatorCompilerForTypeSpecificObjectParameterSeparateClassLo assertThat(aggregation.getCanonicalName()).isEqualTo(LongTimestampAggregation.class.getCanonicalName()); assertThat(aggregation).isNotSameAs(LongTimestampAggregation.class); - assertGenerateAccumulator(aggregation, stateInterface); + assertGenerateAccumulator(aggregation, stateInterface, specializedLoops); } - private static void assertGenerateAccumulator(Class aggregation, Class stateInterface) + private static void assertGenerateAccumulator(Class aggregation, Class stateInterface, boolean specializedLoops) { AccumulatorStateSerializer stateSerializer = StateCompiler.generateStateSerializer(stateInterface); AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateInterface); @@ -105,7 +119,7 @@ private static void assertGenerateAccumulator(Cl FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability); + AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(signature, implementation, functionNullability, specializedLoops); assertThat(accumulatorFactory).isNotNull(); // compile window aggregation diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java new file mode 100644 index 000000000000..53d55a59df6d --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAggregationLoopBuilder.java @@ -0,0 +1,168 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.SqlNullable; +import io.trino.spi.function.SqlType; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestAggregationLoopBuilder +{ + private static final MethodHandle INPUT_FUNCTION; + private static final Object LAMBDA_A = "lambda a"; + private static final Object LAMBDA_B = 1234L; + + static { + try { + INPUT_FUNCTION = lookup().findStatic( + TestAggregationLoopBuilder.class, + "input", + methodType(void.class, InvocationList.class, ValueBlock.class, int.class, ValueBlock.class, int.class, Object.class, Object.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private MethodHandle loop; + private List keyBlocks; + private List valueBlocks; + + @BeforeClass + public void setUp() + throws ReflectiveOperationException + { + loop = buildLoop(INPUT_FUNCTION, 1, 2, false); + + ValueBlock keyBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + ValueBlock keyRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {33}); + ValueBlock keyDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {55, 54, 53}); + + keyBlocks = ImmutableList.builder() + .add(new TestParameter(keyBasic, keyBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(keyRleValue, 5), keyRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, keyDictionary, new int[] {9, 9, 2, 1, 0, 1, 2}).getRegion(2, 5), keyDictionary, new int[] {2, 1, 0, 1, 2})) + .build(); + + ValueBlock valueBasic = new IntArrayBlock(5, Optional.empty(), new int[] {10, 11, 12, 13, 14}); + ValueBlock valueRleValue = new IntArrayBlock(1, Optional.empty(), new int[] {44}); + ValueBlock valueDictionary = new IntArrayBlock(3, Optional.empty(), new int[] {66, 65, 64}); + + valueBlocks = ImmutableList.builder() + .add(new TestParameter(valueBasic, valueBasic, new int[] {0, 1, 2, 3, 4})) + .add(new TestParameter(RunLengthEncodedBlock.create(valueRleValue, 5), valueRleValue, new int[] {0, 0, 0, 0, 0})) + .add(new TestParameter(DictionaryBlock.create(7, valueDictionary, new int[] {9, 9, 0, 1, 2, 1, 0}).getRegion(2, 5), valueDictionary, new int[] {0, 1, 2, 1, 0})) + .build(); + } + + @Test + public void testSelectAll() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectAll(5); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + @Test + public void testMasked() + throws Throwable + { + AggregationMask mask = AggregationMask.createSelectedPositions(5, new int[] {1, 2, 4}, 3); + for (TestParameter keyBlock : keyBlocks) { + for (TestParameter valueBlock : valueBlocks) { + InvocationList invocationList = new InvocationList(); + loop.invokeExact(mask, invocationList, keyBlock.inputBlock(), valueBlock.inputBlock(), LAMBDA_A, LAMBDA_B); + assertThat(invocationList.getInvocations()).isEqualTo(buildExpectedInvocation(keyBlock, valueBlock, mask).getInvocations()); + } + } + } + + private static InvocationList buildExpectedInvocation(TestParameter keyBlock, TestParameter valueBlock, AggregationMask mask) + { + InvocationList invocationList = new InvocationList(); + int[] keyPositions = keyBlock.invokedPositions(); + int[] valuePositions = valueBlock.invokedPositions(); + if (mask.isSelectAll()) { + for (int position = 0; position < keyPositions.length; position++) { + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < mask.getSelectedPositionCount(); i++) { + int position = selectedPositions[i]; + invocationList.add(keyBlock.invokedBlock(), keyPositions[position], valueBlock.invokedBlock(), valuePositions[position], LAMBDA_A, LAMBDA_B); + } + } + return invocationList; + } + + @SuppressWarnings("UnusedVariable") + private record TestParameter(Block inputBlock, ValueBlock invokedBlock, int[] invokedPositions) {} + + public static void input( + @AggregationState InvocationList invocationList, + @BlockPosition @SqlType("K") ValueBlock keyBlock, + @BlockIndex int keyPosition, + @SqlNullable @BlockPosition @SqlType("V") ValueBlock valueBlock, + @BlockIndex int valuePosition, + Object lambdaA, + Object lambdaB) + { + invocationList.add(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB); + } + + public static class InvocationList + { + private final List invocations = new ArrayList<>(); + + public void add(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, Object lambdaA, Object lambdaB) + { + invocations.add(new Invocation(keyBlock, keyPosition, valueBlock, valuePosition, lambdaA, lambdaB)); + } + + public List getInvocations() + { + return ImmutableList.copyOf(invocations); + } + + public record Invocation(ValueBlock keyBlock, int keyPosition, ValueBlock valueBlock, int valuePosition, Object lambdaA, Object lambdaB) {} + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java index a269c53b08b4..835f0f3fb94e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestCountNullAggregation.java @@ -18,6 +18,7 @@ import io.trino.operator.aggregation.state.NullableLongState; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -73,7 +74,7 @@ public static final class CountNull private CountNull() {} @InputFunction - public static void input(@AggregationState NullableLongState state, @BlockPosition @SqlNullable @SqlType(StandardTypes.BIGINT) Block block, @BlockIndex int position) + public static void input(@AggregationState NullableLongState state, @BlockPosition @SqlNullable @SqlType(StandardTypes.BIGINT) ValueBlock block, @BlockIndex int position) { if (block.isNull(position)) { state.setValue(state.getValue() + 1); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java index c960a3c3462e..4554352ed605 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalAverageAggregation.java @@ -17,6 +17,7 @@ import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; @@ -233,7 +234,7 @@ private static void addToState(DecimalType type, LongDecimalWithOverflowAndLongS else { BlockBuilder blockBuilder = type.createFixedSizeBlockBuilder(1); type.writeObject(blockBuilder, Int128.valueOf(value)); - DecimalAverageAggregation.inputLongDecimal(state, blockBuilder.build(), 0); + DecimalAverageAggregation.inputLongDecimal(state, (Int128ArrayBlock) blockBuilder.buildValueBlock(), 0); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java index 7689ab2a3f8d..66ead07005fc 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestDecimalSumAggregation.java @@ -16,6 +16,7 @@ import io.trino.operator.aggregation.state.LongDecimalWithOverflowState; import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; @@ -142,7 +143,7 @@ private static void addToState(LongDecimalWithOverflowState state, BigInteger va else { BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1); TYPE.writeObject(blockBuilder, Int128.valueOf(value)); - DecimalSumAggregation.inputLongDecimal(state, blockBuilder.build(), 0); + DecimalSumAggregation.inputLongDecimal(state, (Int128ArrayBlock) blockBuilder.buildValueBlock(), 0); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java index 24a739c1a597..dcb81d328e2f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestTypedHistogram.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -27,8 +28,8 @@ import java.util.stream.IntStream; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -58,15 +59,15 @@ private static void testMassive(boolean grouped, Type type, ObjIntConsumer IntStream.iterate(value, IntUnaryOperator.identity()).limit(value)) .forEach(value -> writeData.accept(inputBlockBuilder, value)); - Block inputBlock = inputBlockBuilder.build(); + ValueBlock inputBlock = inputBlockBuilder.buildValueBlock(); TypedHistogram typedHistogram = new TypedHistogram( type, TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)), - TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)), TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)), - TYPE_OPERATORS.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)), - TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)), + TYPE_OPERATORS.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)), grouped); int groupId = 0; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index 57253a903202..63ce7741b279 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -51,7 +51,7 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability .collect(toImmutableList()); intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); this.finalType = signature.getReturnType(); - this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability); + this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability, true); distinctFactory = new DistinctAccumulatorFactory( factory, parameterTypes, diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java index a70a662e7c0d..5cc028271641 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/listagg/TestListaggAggregationFunction.java @@ -16,7 +16,7 @@ import io.airlift.slice.Slice; import io.trino.metadata.TestingFunctionResolution; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.sql.analyzer.TypeSignatureProvider; import org.junit.jupiter.api.Test; @@ -46,17 +46,17 @@ public void testInputEmptyState() SingleListaggAggregationState state = new SingleListaggAggregationState(); String s = "value1"; - Block value = createStringsBlock(s); + ValueBlock value = createStringsBlock(s); Slice separator = utf8Slice(","); Slice overflowFiller = utf8Slice("..."); ListaggAggregationFunction.input( state, value, + 0, separator, false, overflowFiller, - true, - 0); + true); VariableWidthBlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 1); state.write(blockBuilder); @@ -74,11 +74,11 @@ public void testInputOverflowOverflowFillerTooLong() assertThatThrownBy(() -> ListaggAggregationFunction.input( state, createStringsBlock("value1"), + 0, utf8Slice(","), false, utf8Slice(overflowFillerTooLong), - false, - 0)) + false)) .isInstanceOf(TrinoException.class) .matches(throwable -> ((TrinoException) throwable).getErrorCode() == INVALID_FUNCTION_ARGUMENT.toErrorCode()); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java index ea89f880d812..0e6acb76d8ca 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxbyn/TestTypedKeyValueHeap.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeUtils; @@ -34,8 +34,8 @@ import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -151,32 +151,32 @@ private static void test(Type keyType, Type valueType, List> private static void test(Type keyType, Type valueType, boolean min, List> testData, Comparator comparator, int capacity) { MethodHandle keyReadFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle keyWriteFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle keyWriteFlat = TYPE_OPERATORS.getReadValueOperator(keyType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle valueReadFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle valueWriteFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle valueWriteFlat = TYPE_OPERATORS.getReadValueOperator(valueType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle comparisonFlatFlat; MethodHandle comparisonFlatBlock; if (min) { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); } else { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(keyType, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); comparator = comparator.reversed(); } - Block expected = toBlock(valueType, testData.stream() + ValueBlock expected = toBlock(valueType, testData.stream() .sorted(comparing(Entry::key, comparator)) .map(Entry::value) .limit(capacity) .toList()); - Block inputKeys = toBlock(keyType, testData.stream().map(Entry::key).toList()); - Block inputValues = toBlock(valueType, testData.stream().map(Entry::value).toList()); + ValueBlock inputKeys = toBlock(keyType, testData.stream().map(Entry::key).toList()); + ValueBlock inputValues = toBlock(valueType, testData.stream().map(Entry::value).toList()); // verify basic build TypedKeyValueHeap heap = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - heap.addAll(inputKeys, inputValues); + getAddAll(heap, inputKeys, inputValues); assertEqual(heap, valueType, expected); // verify copy constructor @@ -185,44 +185,47 @@ private static void test(Type keyType, Type valueType, boolean min, List< // build in two parts and merge together TypedKeyValueHeap part1 = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); int splitPoint = inputKeys.getPositionCount() / 2; - part1.addAll( - inputKeys.getRegion(0, splitPoint), - inputValues.getRegion(0, splitPoint)); + getAddAll(part1, inputKeys.getRegion(0, splitPoint), inputValues.getRegion(0, splitPoint)); BlockBuilder part1KeyBlockBuilder = keyType.createBlockBuilder(null, part1.getCapacity()); BlockBuilder part1ValueBlockBuilder = valueType.createBlockBuilder(null, part1.getCapacity()); part1.writeAllUnsorted(part1KeyBlockBuilder, part1ValueBlockBuilder); - Block part1KeyBlock = part1KeyBlockBuilder.build(); - Block part1ValueBlock = part1ValueBlockBuilder.build(); + ValueBlock part1KeyBlock = part1KeyBlockBuilder.buildValueBlock(); + ValueBlock part1ValueBlock = part1ValueBlockBuilder.buildValueBlock(); TypedKeyValueHeap part2 = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - part2.addAll( - inputKeys.getRegion(splitPoint, inputKeys.getPositionCount() - splitPoint), - inputValues.getRegion(splitPoint, inputValues.getPositionCount() - splitPoint)); + getAddAll(part2, inputKeys.getRegion(splitPoint, inputKeys.getPositionCount() - splitPoint), inputValues.getRegion(splitPoint, inputValues.getPositionCount() - splitPoint)); BlockBuilder part2KeyBlockBuilder = keyType.createBlockBuilder(null, part2.getCapacity()); BlockBuilder part2ValueBlockBuilder = valueType.createBlockBuilder(null, part2.getCapacity()); part2.writeAllUnsorted(part2KeyBlockBuilder, part2ValueBlockBuilder); - Block part2KeyBlock = part2KeyBlockBuilder.build(); - Block part2ValueBlock = part2ValueBlockBuilder.build(); + ValueBlock part2KeyBlock = part2KeyBlockBuilder.buildValueBlock(); + ValueBlock part2ValueBlock = part2ValueBlockBuilder.buildValueBlock(); TypedKeyValueHeap merged = new TypedKeyValueHeap(min, keyReadFlat, keyWriteFlat, valueReadFlat, valueWriteFlat, comparisonFlatFlat, comparisonFlatBlock, keyType, valueType, capacity); - merged.addAll(part1KeyBlock, part1ValueBlock); - merged.addAll(part2KeyBlock, part2ValueBlock); + getAddAll(merged, part1KeyBlock, part1ValueBlock); + getAddAll(merged, part2KeyBlock, part2ValueBlock); assertEqual(merged, valueType, expected); } - private static void assertEqual(TypedKeyValueHeap heap, Type valueType, Block expected) + private static void getAddAll(TypedKeyValueHeap heap, ValueBlock inputKeys, ValueBlock inputValues) + { + for (int i = 0; i < inputKeys.getPositionCount(); i++) { + heap.add(inputKeys, i, inputValues, i); + } + } + + private static void assertEqual(TypedKeyValueHeap heap, Type valueType, ValueBlock expected) { BlockBuilder resultBlockBuilder = valueType.createBlockBuilder(null, OUTPUT_SIZE); heap.writeValuesSorted(resultBlockBuilder); - Block actual = resultBlockBuilder.build(); + ValueBlock actual = resultBlockBuilder.buildValueBlock(); assertBlockEquals(valueType, actual, expected); } - private static Block toBlock(Type type, List inputStream) + private static ValueBlock toBlock(Type type, List inputStream) { BlockBuilder blockBuilder = type.createBlockBuilder(null, INPUT_SIZE); inputStream.forEach(value -> TypeUtils.writeNativeValue(type, blockBuilder, value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } // TODO remove this suppression when the error prone checker actually supports records correctly diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java index ecba7480d853..fe5b3995c55a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/minmaxn/TestTypedHeap.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeUtils; @@ -30,8 +30,8 @@ import java.util.stream.IntStream; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -95,26 +95,26 @@ private static void test(Type type, List testData, Comparator comparat private static void test(Type type, boolean min, List testData, Comparator comparator) { MethodHandle readFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)); - MethodHandle writeFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); + MethodHandle writeFlat = TYPE_OPERATORS.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle comparisonFlatFlat; MethodHandle comparisonFlatBlock; if (min) { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedLastOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); } else { comparisonFlatFlat = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); + comparisonFlatBlock = TYPE_OPERATORS.getComparisonUnorderedFirstOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION_NOT_NULL)); comparator = comparator.reversed(); } - Block expected = toBlock(type, testData.stream().sorted(comparator).limit(OUTPUT_SIZE).toList()); - Block inputData = toBlock(type, testData); + ValueBlock expected = toBlock(type, testData.stream().sorted(comparator).limit(OUTPUT_SIZE).toList()); + ValueBlock inputData = toBlock(type, testData); // verify basic build TypedHeap heap = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - heap.addAll(inputData); + addAll(heap, inputData); assertEqual(heap, type, expected); // verify copy constructor @@ -122,35 +122,42 @@ private static void test(Type type, boolean min, List testData, Comparato // build in two parts and merge together TypedHeap part1 = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - part1.addAll(inputData.getRegion(0, inputData.getPositionCount() / 2)); + addAll(part1, inputData.getRegion(0, inputData.getPositionCount() / 2)); BlockBuilder part1BlockBuilder = type.createBlockBuilder(null, part1.getCapacity()); part1.writeAllUnsorted(part1BlockBuilder); - Block part1Block = part1BlockBuilder.build(); + ValueBlock part1Block = part1BlockBuilder.buildValueBlock(); TypedHeap part2 = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - part2.addAll(inputData.getRegion(inputData.getPositionCount() / 2, inputData.getPositionCount() - (inputData.getPositionCount() / 2))); + addAll(part2, inputData.getRegion(inputData.getPositionCount() / 2, inputData.getPositionCount() - (inputData.getPositionCount() / 2))); BlockBuilder part2BlockBuilder = type.createBlockBuilder(null, part2.getCapacity()); part2.writeAllUnsorted(part2BlockBuilder); - Block part2Block = part2BlockBuilder.build(); + ValueBlock part2Block = part2BlockBuilder.buildValueBlock(); TypedHeap merged = new TypedHeap(min, readFlat, writeFlat, comparisonFlatFlat, comparisonFlatBlock, type, OUTPUT_SIZE); - merged.addAll(part1Block); - merged.addAll(part2Block); + addAll(merged, part1Block); + addAll(merged, part2Block); assertEqual(merged, type, expected); } - private static void assertEqual(TypedHeap heap, Type type, Block expected) + private static void addAll(TypedHeap heap, ValueBlock inputData) + { + for (int i = 0; i < inputData.getPositionCount(); i++) { + heap.add(inputData, i); + } + } + + private static void assertEqual(TypedHeap heap, Type type, ValueBlock expected) { BlockBuilder resultBlockBuilder = type.createBlockBuilder(null, OUTPUT_SIZE); heap.writeAllSorted(resultBlockBuilder); - Block actual = resultBlockBuilder.build(); + ValueBlock actual = resultBlockBuilder.buildValueBlock(); assertBlockEquals(type, actual, expected); } - private static Block toBlock(Type type, List inputStream) + private static ValueBlock toBlock(Type type, List inputStream) { BlockBuilder blockBuilder = type.createBlockBuilder(null, INPUT_SIZE); inputStream.forEach(value -> TypeUtils.writeNativeValue(type, blockBuilder, value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java index 8c9077e03a06..95247484c0f8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/TestLookupJoinPageBuilder.java @@ -20,6 +20,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -100,7 +101,7 @@ public void testDifferentPositions() JoinProbe probe = joinProbeFactory.createJoinProbe(page); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); - assertTrue(output.getBlock(0) instanceof DictionaryBlock); + assertTrue(output.getBlock(0) instanceof LongArrayBlock); assertEquals(output.getPositionCount(), 0); lookupJoinPageBuilder.reset(); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java index 184540e20f01..d2214565dc7b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java @@ -21,6 +21,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -101,7 +102,7 @@ public void testDifferentPositions() JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); - assertTrue(output.getBlock(0) instanceof DictionaryBlock); + assertTrue(output.getBlock(0) instanceof LongArrayBlock); assertEquals(output.getPositionCount(), 0); lookupJoinPageBuilder.reset(); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java index 491949cbf6fe..7dee87dfe24d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java @@ -459,7 +459,7 @@ private static ImmutableList getTypes() IPADDRESS); } - private Block createBlockForType(Type type, int positionsPerPage) + private static Block createBlockForType(Type type, int positionsPerPage) { return createRandomBlockForType(type, positionsPerPage, 0.2F); } @@ -707,7 +707,7 @@ public Stream getEnqueuedDeserialized(int partition) public List getEnqueued(int partition) { Collection serializedPages = enqueued.get(partition); - return serializedPages == null ? ImmutableList.of() : ImmutableList.copyOf(serializedPages); + return ImmutableList.copyOf(serializedPages); } public void throwOnEnqueue(RuntimeException throwOnEnqueue) @@ -813,31 +813,20 @@ public Optional getFailureCause() } } - private static class SumModuloPartitionFunction + private record SumModuloPartitionFunction(int partitionCount, int... hashChannels) implements PartitionFunction { - private final int[] hashChannels; - private final int partitionCount; - - SumModuloPartitionFunction(int partitionCount, int... hashChannels) + private SumModuloPartitionFunction { checkArgument(partitionCount > 0); - this.partitionCount = partitionCount; - this.hashChannels = hashChannels; - } - - @Override - public int getPartitionCount() - { - return partitionCount; } @Override public int getPartition(Page page, int position) { long value = 0; - for (int i = 0; i < hashChannels.length; i++) { - value += page.getBlock(hashChannels[i]).getLong(position, 0); + for (int hashChannel : hashChannels) { + value += page.getBlock(hashChannel).getLong(position, 0); } return toIntExact(Math.abs(value) % partitionCount); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java index 2f5760112780..c25649d3edf2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPositionsAppender.java @@ -24,6 +24,7 @@ import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.block.RowBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.BigintType; @@ -90,8 +91,6 @@ public void testMixedBlockTypes(TestType type) List input = ImmutableList.of( input(emptyBlock(type)), input(nullBlock(type, 3), 0, 2), - input(nullBlock(TestType.UNKNOWN, 3), 0, 2), // a := null projections are handled by UnknownType null block - input(nullBlock(TestType.UNKNOWN, 1), 0), // a := null projections are handled by UnknownType null block, 1 position uses non RLE block input(notNullBlock(type, 3), 1, 2), input(partiallyNullBlock(type, 4), 0, 1, 2, 3), input(partiallyNullBlock(type, 4)), // empty position list @@ -169,7 +168,7 @@ public static Object[][] differentValues() {TestType.INTEGER, createIntsBlock(0), createIntsBlock(1)}, {TestType.CHAR_10, createStringsBlock("0"), createStringsBlock("1")}, {TestType.VARCHAR, createStringsBlock("0"), createStringsBlock("1")}, - {TestType.DOUBLE, createDoublesBlock(0D), createDoublesBlock(1D)}, + {TestType.DOUBLE, createDoublesBlock(0.0), createDoublesBlock(1.0)}, {TestType.SMALLINT, createSmallintsBlock(0), createSmallintsBlock(1)}, {TestType.TINYINT, createTinyintsBlock(0), createTinyintsBlock(1)}, {TestType.VARBINARY, createSlicesBlock(Slices.allocate(Long.BYTES)), createSlicesBlock(Slices.allocate(Long.BYTES).getOutput().appendLong(1).slice())}, @@ -184,7 +183,7 @@ public static Object[][] differentValues() @Test(dataProvider = "types") public void testMultipleRleWithTheSameValueProduceRle(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); Block value = notNullBlock(type, 1); positionsAppender.append(allPositions(3), rleBlock(value, 3)); @@ -199,7 +198,7 @@ public void testMultipleRleWithTheSameValueProduceRle(TestType type) public void testRleAppendForComplexTypeWithNullElement(TestType type, Block value) { checkArgument(value.getPositionCount() == 1); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); positionsAppender.append(allPositions(3), rleBlock(value, 3)); positionsAppender.append(allPositions(2), rleBlock(value, 2)); @@ -213,7 +212,7 @@ public void testRleAppendForComplexTypeWithNullElement(TestType type, Block valu @Test(dataProvider = "types") public void testRleAppendedWithSinglePositionDoesNotProduceRle(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); Block value = notNullBlock(type, 1); positionsAppender.append(allPositions(3), rleBlock(value, 3)); @@ -226,16 +225,16 @@ public void testRleAppendedWithSinglePositionDoesNotProduceRle(TestType type) } @Test(dataProvider = "types") - public void testMultipleTheSameDictionariesProduceDictionary(TestType type) + public static void testMultipleTheSameDictionariesProduceDictionary(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); // test if appender can accept different dictionary after a build testMultipleTheSameDictionariesProduceDictionary(type, positionsAppender); } - private void testMultipleTheSameDictionariesProduceDictionary(TestType type, PositionsAppender positionsAppender) + private static void testMultipleTheSameDictionariesProduceDictionary(TestType type, UnnestingPositionsAppender positionsAppender) { Block dictionary = createRandomBlockForType(type, 4, 0); positionsAppender.append(allPositions(3), createRandomDictionaryBlock(dictionary, 3)); @@ -279,11 +278,11 @@ public void testDictionarySingleThenFlat(TestType type) { BlockView firstInput = input(dictionaryBlock(type, 1, 4, 0), 0); BlockView secondInput = input(dictionaryBlock(type, 2, 4, 0), 0, 1); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - firstInput.getPositions().forEach((int position) -> positionsAppender.append(position, firstInput.getBlock())); - positionsAppender.append(secondInput.getPositions(), secondInput.getBlock()); + firstInput.positions().forEach((int position) -> positionsAppender.append(position, firstInput.block())); + positionsAppender.append(secondInput.positions(), secondInput.block()); assertBuildResult(type, ImmutableList.of(firstInput, secondInput), positionsAppender, initialRetainedSize); } @@ -291,7 +290,7 @@ public void testDictionarySingleThenFlat(TestType type) @Test(dataProvider = "types") public void testConsecutiveBuilds(TestType type) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // empty block positionsAppender.append(positions(), emptyBlock(type)); @@ -329,17 +328,17 @@ public void testConsecutiveBuilds(TestType type) } // testcase for jit bug described https://github.com/trinodb/trino/issues/12821. - // this test needs to be run first (hence lowest priority) as order of tests - // influence jit compilation making this problem to not occur if other tests are run first. + // this test needs to be run first (hence the lowest priority) as the test order + // influences jit compilation, making this problem to not occur if other tests are run first. @Test(priority = Integer.MIN_VALUE) public void testSliceRle() { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(VARCHAR, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(VARCHAR, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // first append some not empty value to avoid RleAwarePositionsAppender for the empty value positionsAppender.appendRle(singleValueBlock("some value"), 1); // append empty value multiple times to trigger jit compilation - Block emptyStringBlock = singleValueBlock(""); + ValueBlock emptyStringBlock = singleValueBlock(""); for (int i = 0; i < 1000; i++) { positionsAppender.appendRle(emptyStringBlock, 2000); } @@ -355,7 +354,7 @@ public void testRowWithNestedFields() rleBlock(TestType.VARCHAR, 2) }); - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); positionsAppender.append(allPositions(2), rowBLock); Block actual = positionsAppender.build(); @@ -375,24 +374,24 @@ public static Object[][] complexTypesWithNullElementBlock() public static Object[][] types() { return Arrays.stream(TestType.values()) - .filter(testType -> !testType.equals(TestType.UNKNOWN)) + .filter(testType -> testType != TestType.UNKNOWN) .map(type -> new Object[] {type}) .toArray(Object[][]::new); } - private static Block singleValueBlock(String value) + private static ValueBlock singleValueBlock(String value) { BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 1); VARCHAR.writeSlice(blockBuilder, Slices.utf8Slice(value)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } - private IntArrayList allPositions(int count) + private static IntArrayList allPositions(int count) { return new IntArrayList(IntStream.range(0, count).toArray()); } - private BlockView input(Block block, int... positions) + private static BlockView input(Block block, int... positions) { return new BlockView(block, new IntArrayList(positions)); } @@ -402,53 +401,53 @@ private static IntArrayList positions(int... positions) return new IntArrayList(positions); } - private Block dictionaryBlock(Block dictionary, int positionCount) + private static Block dictionaryBlock(Block dictionary, int positionCount) { return createRandomDictionaryBlock(dictionary, positionCount); } - private Block dictionaryBlock(Block dictionary, int[] ids) + private static Block dictionaryBlock(Block dictionary, int[] ids) { return DictionaryBlock.create(ids.length, dictionary, ids); } - private Block dictionaryBlock(TestType type, int positionCount, int dictionarySize, float nullRate) + private static Block dictionaryBlock(TestType type, int positionCount, int dictionarySize, float nullRate) { Block dictionary = createRandomBlockForType(type, dictionarySize, nullRate); return createRandomDictionaryBlock(dictionary, positionCount); } - private RunLengthEncodedBlock rleBlock(Block value, int positionCount) + private static RunLengthEncodedBlock rleBlock(Block value, int positionCount) { checkArgument(positionCount >= 2); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(value, positionCount); } - private RunLengthEncodedBlock rleBlock(TestType type, int positionCount) + private static RunLengthEncodedBlock rleBlock(TestType type, int positionCount) { checkArgument(positionCount >= 2); Block rleValue = createRandomBlockForType(type, 1, 0); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(rleValue, positionCount); } - private RunLengthEncodedBlock nullRleBlock(TestType type, int positionCount) + private static RunLengthEncodedBlock nullRleBlock(TestType type, int positionCount) { checkArgument(positionCount >= 2); Block rleValue = nullBlock(type, 1); return (RunLengthEncodedBlock) RunLengthEncodedBlock.create(rleValue, positionCount); } - private Block partiallyNullBlock(TestType type, int positionCount) + private static Block partiallyNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0.5F); } - private Block notNullBlock(TestType type, int positionCount) + private static Block notNullBlock(TestType type, int positionCount) { return createRandomBlockForType(type, positionCount, 0); } - private Block nullBlock(TestType type, int positionCount) + private static Block nullBlock(TestType type, int positionCount) { BlockBuilder blockBuilder = type.getType().createBlockBuilder(null, positionCount); for (int i = 0; i < positionCount; i++) { @@ -466,19 +465,19 @@ private static Block nullBlock(Type type, int positionCount) return blockBuilder.build(); } - private Block emptyBlock(TestType type) + private static Block emptyBlock(TestType type) { return type.adapt(type.getType().createBlockBuilder(null, 0).build()); } - private Block createRandomBlockForType(TestType type, int positionCount, float nullRate) + private static Block createRandomBlockForType(TestType type, int positionCount, float nullRate) { return type.adapt(BlockAssertions.createRandomBlockForType(type.getType(), positionCount, nullRate)); } - private void testNullRle(Type type, Block source) + private static void testNullRle(Type type, Block source) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type, 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); // extract null positions IntArrayList positions = new IntArrayList(source.getPositionCount()); for (int i = 0; i < source.getPositionCount(); i++) { @@ -495,22 +494,22 @@ private void testNullRle(Type type, Block source) assertInstanceOf(actual, RunLengthEncodedBlock.class); } - private void testAppend(TestType type, List inputs) + private static void testAppend(TestType type, List inputs) { testAppendBatch(type, inputs); testAppendSingle(type, inputs); } - private void testAppendBatch(TestType type, List inputs) + private static void testAppendBatch(TestType type, List inputs) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - inputs.forEach(input -> positionsAppender.append(input.getPositions(), input.getBlock())); + inputs.forEach(input -> positionsAppender.append(input.positions(), input.block())); assertBuildResult(type, inputs, positionsAppender, initialRetainedSize); } - private void assertBuildResult(TestType type, List inputs, PositionsAppender positionsAppender, long initialRetainedSize) + private static void assertBuildResult(TestType type, List inputs, UnnestingPositionsAppender positionsAppender, long initialRetainedSize) { long sizeInBytes = positionsAppender.getSizeInBytes(); assertGreaterThanOrEqual(positionsAppender.getRetainedSizeInBytes(), sizeInBytes); @@ -524,12 +523,12 @@ private void assertBuildResult(TestType type, List inputs, PositionsA assertEquals(secondBlock.getPositionCount(), 0); } - private void testAppendSingle(TestType type, List inputs) + private static void testAppendSingle(TestType type, List inputs) { - PositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + UnnestingPositionsAppender positionsAppender = POSITIONS_APPENDER_FACTORY.create(type.getType(), 10, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); long initialRetainedSize = positionsAppender.getRetainedSizeInBytes(); - inputs.forEach(input -> input.getPositions().forEach((int position) -> positionsAppender.append(position, input.getBlock()))); + inputs.forEach(input -> input.positions().forEach((int position) -> positionsAppender.append(position, input.block()))); long sizeInBytes = positionsAppender.getSizeInBytes(); assertGreaterThanOrEqual(positionsAppender.getRetainedSizeInBytes(), sizeInBytes); Block actual = positionsAppender.build(); @@ -542,7 +541,7 @@ private void testAppendSingle(TestType type, List inputs) assertEquals(secondBlock.getPositionCount(), 0); } - private void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List inputs) + private static void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List inputs) { PageBuilderStatus pageBuilderStatus = new PageBuilderStatus(); BlockBuilderStatus blockBuilderStatus = pageBuilderStatus.createBlockBuilderStatus(); @@ -552,12 +551,12 @@ private void assertBlockIsValid(Block actual, long sizeInBytes, Type type, List< assertEquals(sizeInBytes, pageBuilderStatus.getSizeInBytes()); } - private Block buildBlock(Type type, List inputs, BlockBuilderStatus blockBuilderStatus) + private static Block buildBlock(Type type, List inputs, BlockBuilderStatus blockBuilderStatus) { BlockBuilder blockBuilder = type.createBlockBuilder(blockBuilderStatus, 10); for (BlockView input : inputs) { - for (int position : input.getPositions()) { - type.appendTo(input.getBlock(), position, blockBuilder); + for (int position : input.positions()) { + type.appendTo(input.block(), position, blockBuilder); } } return blockBuilder.build(); @@ -606,30 +605,12 @@ public Type getType() } } - private static class BlockView + private record BlockView(Block block, IntArrayList positions) { - private final Block block; - private final IntArrayList positions; - - private BlockView(Block block, IntArrayList positions) - { - this.block = requireNonNull(block, "block is null"); - this.positions = requireNonNull(positions, "positions is null"); - } - - public Block getBlock() - { - return block; - } - - public IntArrayList getPositions() - { - return positions; - } - - public void appendTo(PositionsAppender positionsAppender) + private BlockView { - positionsAppender.append(getPositions(), getBlock()); + requireNonNull(block, "block is null"); + requireNonNull(positions, "positions is null"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java index 3f00f88f3130..5a61bf221287 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java @@ -28,14 +28,14 @@ import static io.trino.spi.type.BigintType.BIGINT; import static org.assertj.core.api.Assertions.assertThat; -public class TestSkewedPartitionRebalancer +class TestSkewedPartitionRebalancer { private static final long MIN_PARTITION_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(1, MEGABYTE).toBytes(); private static final long MIN_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(50, MEGABYTE).toBytes(); private static final int MAX_REBALANCED_PARTITIONS = 30; @Test - public void testRebalanceWithSkewness() + void testRebalanceWithSkewness() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -51,7 +51,7 @@ public void testRebalanceWithSkewness() rebalancer.addPartitionRowCount(1, 1000); rebalancer.addPartitionRowCount(2, 1000); rebalancer.addDataProcessed(DataSize.of(40, MEGABYTE).toBytes()); - // No rebalancing will happen since data processed is less than 50MB limit + // No rebalancing will happen since the data processed is less than 50MB rebalancer.rebalance(); assertThat(getPartitionPositions(function, 17)) @@ -96,7 +96,7 @@ public void testRebalanceWithSkewness() } @Test - public void testRebalanceWithoutSkewness() + void testRebalanceWithoutSkewness() { int partitionCount = 6; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -128,7 +128,7 @@ public void testRebalanceWithoutSkewness() } @Test - public void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() + void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -157,7 +157,7 @@ public void testNoRebalanceWhenDataWrittenIsLessThanTheRebalanceLimit() } @Test - public void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingMinDataProcessed() + void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingMinDataProcessed() { int partitionCount = 3; long minPartitionDataProcessedRebalanceThreshold = DataSize.of(50, MEGABYTE).toBytes(); @@ -187,7 +187,7 @@ public void testNoRebalanceWhenDataWrittenByThePartitionIsLessThanWriterScalingM } @Test - public void testRebalancePartitionToSingleTaskInARebalancingLoop() + void testRebalancePartitionToSingleTaskInARebalancingLoop() { int partitionCount = 3; SkewedPartitionRebalancer rebalancer = new SkewedPartitionRebalancer( @@ -204,7 +204,7 @@ public void testRebalancePartitionToSingleTaskInARebalancingLoop() rebalancer.addPartitionRowCount(2, 0); rebalancer.addDataProcessed(DataSize.of(60, MEGABYTE).toBytes()); - // rebalancing will only happen to single task even though two tasks are available + // rebalancing will only happen to a single task even though two tasks are available rebalancer.rebalance(); assertThat(getPartitionPositions(function, 17)) @@ -344,10 +344,10 @@ public void testRebalancePartitionWithMaxRebalancedPartitionsPerTask() .containsExactly(ImmutableList.of(0, 1), ImmutableList.of(1, 0), ImmutableList.of(2)); } - private List> getPartitionPositions(PartitionFunction function, int maxPosition) + private static List> getPartitionPositions(PartitionFunction function, int maxPosition) { List> partitionPositions = new ArrayList<>(); - for (int partition = 0; partition < function.getPartitionCount(); partition++) { + for (int partition = 0; partition < function.partitionCount(); partition++) { partitionPositions.add(new ArrayList<>()); } @@ -364,22 +364,9 @@ private static Page dummyPage() return SequencePageBuilder.createSequencePage(ImmutableList.of(BIGINT), 100, 0); } - private static class TestPartitionFunction + private record TestPartitionFunction(int partitionCount) implements PartitionFunction { - private final int partitionCount; - - private TestPartitionFunction(int partitionCount) - { - this.partitionCount = partitionCount; - } - - @Override - public int getPartitionCount() - { - return partitionCount; - } - @Override public int getPartition(Page page, int position) { diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java index f9470fe93eaf..c90ffe5234b0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSlicePositionsAppender.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -34,7 +35,7 @@ public void testAppendEmptySliceRle() { // test SlicePositionAppender.appendRle with empty value (Slice with length 0) PositionsAppender positionsAppender = new SlicePositionsAppender(1, 100); - Block value = createStringsBlock(""); + ValueBlock value = createStringsBlock(""); positionsAppender.appendRle(value, 10); Block actualBlock = positionsAppender.build(); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java index 0181173f5903..40534ddbd362 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestBlockAndPositionNullConvention.java @@ -15,7 +15,7 @@ import io.airlift.slice.Slice; import io.trino.metadata.InternalFunctionBundle; -import io.trino.spi.block.Block; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; import io.trino.spi.function.ScalarFunction; @@ -51,6 +51,7 @@ public void init() assertions = new QueryAssertions(); assertions.addFunctions(InternalFunctionBundle.builder() .scalar(FunctionWithBlockAndPositionConvention.class) + .scalar(FunctionWithValueBlockAndPositionConvention.class) .build()); } @@ -105,7 +106,7 @@ public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlTyp @TypeParameter("E") @SqlNullable @SqlType("E") - public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") Block block, @BlockIndex int position) + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") ValueBlock block, @BlockIndex int position) { hitBlockPositionObject.set(true); return readNativeValue(type, block, position); @@ -124,7 +125,7 @@ public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @TypeParameter("E") @SqlNullable @SqlType("E") - public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) Block block, @BlockIndex int position) + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionSlice.set(true); return type.getSlice(block, position); @@ -141,7 +142,7 @@ public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNul @TypeParameter("E") @SqlNullable @SqlType("E") - public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) Block block, @BlockIndex int position) + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionBoolean.set(true); return type.getBoolean(block, position); @@ -158,7 +159,7 @@ public static Long getLong(@SqlNullable @SqlType(StandardTypes.BIGINT) Long numb @SqlType(StandardTypes.BIGINT) @SqlNullable - public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) Block block, @BlockIndex int position) + public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionBigint.set(true); return BIGINT.getLong(block, position); @@ -173,7 +174,126 @@ public static Double getDouble(@SqlNullable @SqlType(StandardTypes.DOUBLE) Doubl @SqlType(StandardTypes.DOUBLE) @SqlNullable - public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) Block block, @BlockIndex int position) + public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionDouble.set(true); + return DOUBLE.getDouble(block, position); + } + } + + @Test + public void testValueBlockPosition() + { + assertThat(assertions.function("test_value_block_position", "BIGINT '1234'")) + .isEqualTo(1234L); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionBigint.get()); + + assertThat(assertions.function("test_value_block_position", "12.34e0")) + .isEqualTo(12.34); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionDouble.get()); + + assertThat(assertions.function("test_value_block_position", "'hello'")) + .hasType(createVarcharType(5)) + .isEqualTo("hello"); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionSlice.get()); + + assertThat(assertions.function("test_value_block_position", "true")) + .isEqualTo(true); + assertTrue(FunctionWithValueBlockAndPositionConvention.hitBlockPositionBoolean.get()); + } + + @ScalarFunction("test_value_block_position") + public static final class FunctionWithValueBlockAndPositionConvention + { + private static final AtomicBoolean hitBlockPositionBigint = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionDouble = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionSlice = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionBoolean = new AtomicBoolean(); + private static final AtomicBoolean hitBlockPositionObject = new AtomicBoolean(); + + // generic implementations + // these will not work right now because MethodHandle is not properly adapted + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Object object) + { + return object; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Object generic(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType("E") ValueBlock block, @BlockIndex int position) + { + hitBlockPositionObject.set(true); + return readNativeValue(type, block, position); + } + + // specialized + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Slice slice) + { + return slice; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Slice specializedSlice(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = Slice.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionSlice.set(true); + return type.getSlice(block, position); + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @SqlType("E") Boolean bool) + { + return bool; + } + + @TypeParameter("E") + @SqlNullable + @SqlType("E") + public static Boolean speciailizedBoolean(@TypeParameter("E") Type type, @SqlNullable @BlockPosition @SqlType(value = "E", nativeContainerType = boolean.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionBoolean.set(true); + return type.getBoolean(block, position); + } + + // exact + + @SqlType(StandardTypes.BIGINT) + @SqlNullable + public static Long getLong(@SqlNullable @SqlType(StandardTypes.BIGINT) Long number) + { + return number; + } + + @SqlType(StandardTypes.BIGINT) + @SqlNullable + public static Long getBlockPosition(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.BIGINT, nativeContainerType = long.class) ValueBlock block, @BlockIndex int position) + { + hitBlockPositionBigint.set(true); + return BIGINT.getLong(block, position); + } + + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlNullable @SqlType(StandardTypes.DOUBLE) Double number) + { + return number; + } + + @SqlType(StandardTypes.DOUBLE) + @SqlNullable + public static Double getDouble(@SqlNullable @BlockPosition @SqlType(value = StandardTypes.DOUBLE, nativeContainerType = double.class) ValueBlock block, @BlockIndex int position) { hitBlockPositionDouble.set(true); return DOUBLE.getDouble(block, position); diff --git a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java index 13a9443c1277..ee754fa020ce 100644 --- a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java +++ b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java @@ -238,7 +238,7 @@ private static class FourFixedPartitionsPartitionFunction } @Override - public int getPartitionCount() + public int partitionCount() { return 4; } @@ -274,7 +274,7 @@ private static class ModuloPartitionFunction } @Override - public int getPartitionCount() + public int partitionCount() { return partitionCount; } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java index 62775fc986d4..e0b1ae1407cc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; +import io.trino.spi.block.VariableWidthBlock; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SpecialForm; @@ -175,14 +176,20 @@ private static boolean filter(int position, Block discountBlock, Block shipDateB private static boolean lessThan(Block left, int leftPosition, Slice right) { + VariableWidthBlock leftBlock = (VariableWidthBlock) left.getUnderlyingValueBlock(); + Slice leftSlice = leftBlock.getRawSlice(); + int leftOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = left.getSliceLength(leftPosition); - return left.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()) < 0; + return leftSlice.compareTo(leftOffset, leftLength, right, 0, right.length()) < 0; } private static boolean greaterThanOrEqual(Block left, int leftPosition, Slice right) { + VariableWidthBlock leftBlock = (VariableWidthBlock) left.getUnderlyingValueBlock(); + Slice leftSlice = leftBlock.getRawSlice(); + int leftOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = left.getSliceLength(leftPosition); - return left.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()) >= 0; + return leftSlice.compareTo(leftOffset, leftLength, right, 0, right.length()) >= 0; } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java index 51c01c21caa9..77ee738f4f1d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestCompilerConfig.java @@ -28,16 +28,21 @@ public class TestCompilerConfig public void testDefaults() { assertRecordedDefaults(recordDefaults(CompilerConfig.class) - .setExpressionCacheSize(10_000)); + .setExpressionCacheSize(10_000) + .setSpecializeAggregationLoops(true)); } @Test public void testExplicitPropertyMappings() { - Map properties = ImmutableMap.of("compiler.expression-cache-size", "52"); + Map properties = ImmutableMap.builder() + .put("compiler.expression-cache-size", "52") + .put("compiler.specialized-aggregation-loops", "false") + .buildOrThrow(); CompilerConfig expected = new CompilerConfig() - .setExpressionCacheSize(52); + .setExpressionCacheSize(52) + .setSpecializeAggregationLoops(false); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java index 608102463fbb..4583fb3792e7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -81,7 +81,7 @@ public void testSimpleDeletedRowMerge() // Show that the row to be deleted is rowId 0, e.g. ('Dave', 11, 'Devon') SqlRow rowIdRow = outputPage.getBlock(4).getObject(0, SqlRow.class); - assertThat(INTEGER.getInt(rowIdRow.getRawFieldBlock(1), rowIdRow.getRawIndex())).isEqualTo(0); + assertThat(BIGINT.getLong(rowIdRow.getRawFieldBlock(1), rowIdRow.getRawIndex())).isEqualTo(0); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java index 45b2fd1eff8b..f88bde2a1f88 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformUncorrelatedSubqueryToJoin.java @@ -120,7 +120,7 @@ public void testRewriteRightCorrelatedJoin() .matches( project( ImmutableMap.of( - "a", expression("if(b > a, a, null)"), + "a", expression("if(b > a, a, cast(null AS BIGINT))"), "b", expression("b")), join(Type.INNER, builder -> builder .left(values("a")) diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 0ea4a710916d..19d8600f960d 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; @@ -59,9 +60,10 @@ import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -89,7 +91,7 @@ public abstract class AbstractTestType private final BlockEncodingSerde blockEncodingSerde = new TestingBlockEncodingSerde(); private final Class objectValueType; - private final Block testBlock; + private final ValueBlock testBlock; protected final Type type; private final TypeOperators typeOperators; @@ -116,35 +118,35 @@ public abstract class AbstractTestType private final BlockPositionIsDistinctFrom distinctFromOperator; private final SortedMap expectedStackValues; private final SortedMap expectedObjectValues; - private final Block testBlockWithNulls; + private final ValueBlock testBlockWithNulls; - protected AbstractTestType(Type type, Class objectValueType, Block testBlock) + protected AbstractTestType(Type type, Class objectValueType, ValueBlock testBlock) { this(type, objectValueType, testBlock, testBlock); } - protected AbstractTestType(Type type, Class objectValueType, Block testBlock, Block expectedValues) + protected AbstractTestType(Type type, Class objectValueType, ValueBlock testBlock, ValueBlock expectedValues) { this.type = requireNonNull(type, "type is null"); typeOperators = new TypeOperators(); - readBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + readBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); writeBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, NEVER_NULL)); writeFlatToBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)); readFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); writeFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, NEVER_NULL)); - writeBlockToFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION)); + writeBlockToFlatMethod = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION)); blockTypeOperators = new BlockTypeOperators(typeOperators); if (type.isComparable()) { stackStackEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); flatFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, FLAT, FLAT)); - flatBlockPositionEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, FLAT, BLOCK_POSITION)); - blockPositionFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, BLOCK_POSITION, FLAT)); + flatBlockPositionEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, FLAT, VALUE_BLOCK_POSITION)); + blockPositionFlatEqualOperator = typeOperators.getEqualOperator(type, simpleConvention(NULLABLE_RETURN, VALUE_BLOCK_POSITION, FLAT)); flatHashCodeOperator = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); flatXxHash64Operator = typeOperators.getXxHash64Operator(type, simpleConvention(FAIL_ON_NULL, FLAT)); flatFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, FLAT)); - flatBlockPositionDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION)); - blockPositionFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, FLAT)); + flatBlockPositionDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, VALUE_BLOCK_POSITION)); + blockPositionFlatDistinctFromOperator = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION, FLAT)); equalOperator = blockTypeOperators.getEqualOperator(type); hashCodeOperator = blockTypeOperators.getHashCodeOperator(type); @@ -176,7 +178,7 @@ protected AbstractTestType(Type type, Class objectValueType, Block testBlock, this.testBlockWithNulls = createAlternatingNullsBlock(testBlock); } - private Block createAlternatingNullsBlock(Block testBlock) + private ValueBlock createAlternatingNullsBlock(Block testBlock) { BlockBuilder nullsBlockBuilder = type.createBlockBuilder(null, testBlock.getPositionCount()); for (int position = 0; position < testBlock.getPositionCount(); position++) { @@ -202,7 +204,7 @@ else if (type.getJavaType() == Slice.class) { } nullsBlockBuilder.appendNull(); } - return nullsBlockBuilder.build(); + return nullsBlockBuilder.buildValueBlock(); } @Test @@ -337,7 +339,7 @@ else if (stackStackEqualOperator != null) { assertFalse((boolean) flatBlockPositionDistinctFromOperator.invokeExact(fixed, elementFixedOffset, variable, testBlock, i)); assertFalse((boolean) blockPositionFlatDistinctFromOperator.invokeExact(testBlock, i, fixed, elementFixedOffset, variable)); - Block nullValue = type.createBlockBuilder(null, 1).appendNull().build(); + ValueBlock nullValue = type.createBlockBuilder(null, 1).appendNull().buildValueBlock(); assertTrue((boolean) flatBlockPositionDistinctFromOperator.invokeExact(fixed, elementFixedOffset, variable, nullValue, 0)); assertTrue((boolean) blockPositionFlatDistinctFromOperator.invokeExact(nullValue, 0, fixed, elementFixedOffset, variable)); } @@ -349,7 +351,7 @@ protected Object getSampleValue() return requireNonNull(Iterables.get(expectedStackValues.values(), 0), "sample value is null"); } - protected void assertPositionEquals(Block block, int position, Object expectedStackValue, Object expectedObjectValue) + protected void assertPositionEquals(ValueBlock block, int position, Object expectedStackValue, Object expectedObjectValue) throws Throwable { long hash = 0; @@ -364,16 +366,16 @@ protected void assertPositionEquals(Block block, int position, Object expectedSt BlockBuilder blockBuilder = type.createBlockBuilder(null, 1); type.appendTo(block, position, blockBuilder); - assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + assertPositionValue(blockBuilder.buildValueBlock(), 0, expectedStackValue, hash, expectedObjectValue); if (expectedStackValue != null) { blockBuilder = type.createBlockBuilder(null, 1); writeBlockMethod.invoke(expectedStackValue, blockBuilder); - assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + assertPositionValue(blockBuilder.buildValueBlock(), 0, expectedStackValue, hash, expectedObjectValue); } } - private void assertPositionValue(Block block, int position, Object expectedStackValue, long expectedHash, Object expectedObjectValue) + private void assertPositionValue(ValueBlock block, int position, Object expectedStackValue, long expectedHash, Object expectedObjectValue) throws Throwable { assertEquals(block.isNull(position), expectedStackValue == null); @@ -643,7 +645,7 @@ else if (javaType == Slice.class) { else { type.writeObject(blockBuilder, value); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } /** @@ -727,7 +729,7 @@ else if (javaType == Slice.class) { else { type.writeObject(blockBuilder, value); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } private static SortedMap indexStackValues(Type type, Block block) diff --git a/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java index 5b6e95354165..7c04b7e78363 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestArrayOfMapOfBigintVarcharType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.ArrayType; import org.junit.jupiter.api.Test; @@ -39,7 +39,7 @@ public TestArrayOfMapOfBigintVarcharType() super(TYPE, List.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TYPE.createBlockBuilder(null, 4); TYPE.writeObject(blockBuilder, arrayBlockOf(TYPE.getElementType(), @@ -51,7 +51,7 @@ public static Block createTestBlock() TYPE.writeObject(blockBuilder, arrayBlockOf(TYPE.getElementType(), sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(100, "hundred")), sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(200, "two hundred")))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java index d486da1a824c..7f466ba8e8bc 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestBigintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(BIGINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(BIGINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(BIGINT, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } BIGINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestBigintType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintType.java index 91ca30931c58..c22a8800f2ad 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestBigintType() super(BIGINT, Long.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 15); BIGINT.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() BIGINT.writeLong(blockBuilder, 3333); BIGINT.writeLong(blockBuilder, 3333); BIGINT.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java index 1ba0f98ce4dd..88f279d6dd75 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBigintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestBigintVarcharMapType() super(mapType(BIGINT, VARCHAR), Map.class, createTestBlock(mapType(BIGINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(BIGINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java b/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java index c26e18547374..b267d02f8c9c 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBooleanType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.BooleanType; import org.junit.jupiter.api.Test; @@ -66,7 +67,7 @@ public void testBooleanBlockWithSingleNonNullValue() assertFalse(BooleanType.createBlockForSingleNonNullValue(false).mayHaveNull()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = BOOLEAN.createBlockBuilder(null, 15); BOOLEAN.writeBoolean(blockBuilder, true); @@ -80,7 +81,7 @@ public static Block createTestBlock() BOOLEAN.writeBoolean(blockBuilder, true); BOOLEAN.writeBoolean(blockBuilder, true); BOOLEAN.writeBoolean(blockBuilder, false); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java index e482515905a0..7874cea36276 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestBoundedVarcharType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import org.junit.jupiter.api.Test; @@ -34,7 +34,7 @@ public TestBoundedVarcharType() super(createVarcharType(6), String.class, createTestBlock(createVarcharType(6))); } - private static Block createTestBlock(VarcharType type) + private static ValueBlock createTestBlock(VarcharType type) { BlockBuilder blockBuilder = type.createBlockBuilder(null, 15); type.writeString(blockBuilder, "apple"); @@ -48,7 +48,7 @@ private static Block createTestBlock(VarcharType type) type.writeString(blockBuilder, "cherry"); type.writeString(blockBuilder, "cherry"); type.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestCharType.java b/core/trino-main/src/test/java/io/trino/type/TestCharType.java index 77b55e254241..333fa0086fb3 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestCharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestCharType.java @@ -18,6 +18,7 @@ import io.airlift.slice.Slices; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.CharType; import io.trino.spi.type.Type; @@ -45,7 +46,7 @@ public TestCharType() super(CHAR_TYPE, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = CHAR_TYPE.createBlockBuilder(null, 15); CHAR_TYPE.writeString(blockBuilder, "apple"); @@ -59,7 +60,7 @@ public static Block createTestBlock() CHAR_TYPE.writeString(blockBuilder, "cherry"); CHAR_TYPE.writeString(blockBuilder, "cherry"); CHAR_TYPE.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java index 01b283cd73aa..f2d81fe09119 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestColorArrayType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,14 +35,14 @@ public TestColorArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(COLOR.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(COLOR.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(COLOR, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestColorType.java b/core/trino-main/src/test/java/io/trino/type/TestColorType.java index 6b4557f9487f..3d640f0d3f80 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestColorType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestColorType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.operator.scalar.ColorFunctions.rgb; @@ -54,7 +55,7 @@ public void testGetObjectValue() } } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = COLOR.createBlockBuilder(null, 15); COLOR.writeLong(blockBuilder, rgb(1, 1, 1)); @@ -68,7 +69,7 @@ public static Block createTestBlock() COLOR.writeLong(blockBuilder, rgb(3, 3, 3)); COLOR.writeLong(blockBuilder, rgb(3, 3, 3)); COLOR.writeLong(blockBuilder, rgb(4, 4, 4)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java b/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java index fa1ebfd0f75d..309179070ac6 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java +++ b/core/trino-main/src/test/java/io/trino/type/TestConventionDependencies.java @@ -16,6 +16,8 @@ import io.trino.metadata.InternalFunctionBundle; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; import io.trino.spi.function.Convention; @@ -37,6 +39,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.IntegerType.INTEGER; import static org.assertj.core.api.Assertions.assertThat; @@ -55,6 +58,7 @@ public void init() assertions.addFunctions(InternalFunctionBundle.builder() .scalar(RegularConvention.class) .scalar(BlockPositionConvention.class) + .scalar(ValueBlockPositionConvention.class) .scalar(Add.class) .build()); @@ -88,6 +92,15 @@ public void testConventionDependencies() assertThat(assertions.function("block_position_convention", "ARRAY[56, 275, 36]")) .isEqualTo(367); + + assertThat(assertions.function("value_block_position_convention", "ARRAY[1, 2, 3]")) + .isEqualTo(6); + + assertThat(assertions.function("value_block_position_convention", "ARRAY[25, 0, 5]")) + .isEqualTo(30); + + assertThat(assertions.function("value_block_position_convention", "ARRAY[56, 275, 36]")) + .isEqualTo(367); } @ScalarFunction("regular_convention") @@ -138,6 +151,34 @@ public static long testBlockPositionConvention( } } + @ScalarFunction("value_block_position_convention") + public static final class ValueBlockPositionConvention + { + @SqlType(StandardTypes.INTEGER) + public static long testBlockPositionConvention( + @FunctionDependency( + name = "add", + argumentTypes = {StandardTypes.INTEGER, StandardTypes.INTEGER}, + convention = @Convention(arguments = {NEVER_NULL, VALUE_BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle function, + @SqlType("array(integer)") Block array) + { + ValueBlock arrayValues = array.getUnderlyingValueBlock(); + + long sum = 0; + for (int i = 0; i < array.getPositionCount(); i++) { + try { + sum = (long) function.invokeExact(sum, arrayValues, array.getUnderlyingValuePosition(i)); + } + catch (Throwable t) { + throwIfInstanceOf(t, Error.class); + throwIfInstanceOf(t, TrinoException.class); + throw new TrinoException(GENERIC_INTERNAL_ERROR, t); + } + } + return sum; + } + } + @ScalarFunction("add") public static final class Add { @@ -152,7 +193,7 @@ public static long add( @SqlType(StandardTypes.INTEGER) public static long addBlockPosition( @SqlType(StandardTypes.INTEGER) long first, - @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) Block block, + @BlockPosition @SqlType(value = StandardTypes.INTEGER, nativeContainerType = long.class) IntArrayBlock block, @BlockIndex int position) { return Math.addExact((int) first, INTEGER.getInt(block, position)); diff --git a/core/trino-main/src/test/java/io/trino/type/TestDateType.java b/core/trino-main/src/test/java/io/trino/type/TestDateType.java index 8d27947777ff..9e3565cfcb86 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestDateType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestDateType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlDate; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -33,7 +33,7 @@ public TestDateType() super(DATE, SqlDate.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = DATE.createBlockBuilder(null, 15); DATE.writeLong(blockBuilder, 1111); @@ -47,7 +47,7 @@ public static Block createTestBlock() DATE.writeLong(blockBuilder, 3333); DATE.writeLong(blockBuilder, 3333); DATE.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java b/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java index 73c366d7e3d0..d13a3eee46c1 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestDoubleType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.LongArrayBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import io.trino.type.BlockTypeOperators.BlockPositionXxHash64; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public TestDoubleType() super(DOUBLE, Double.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, 15); DOUBLE.writeDouble(blockBuilder, 11.11); @@ -48,7 +49,7 @@ public static Block createTestBlock() DOUBLE.writeDouble(blockBuilder, 33.33); DOUBLE.writeDouble(blockBuilder, 33.33); DOUBLE.writeDouble(blockBuilder, 44.44); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java index 5d7853f2aa44..40dda9308495 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestIntegerArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(INTEGER.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(INTEGER.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(INTEGER, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } INTEGER.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java index 798a449d58b5..2a76b84a3c3a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestIntegerType() super(INTEGER, Integer.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 15); INTEGER.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() INTEGER.writeLong(blockBuilder, 3333); INTEGER.writeLong(blockBuilder, 3333); INTEGER.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java index 69287415be6f..5349a0475f2a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntegerVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestIntegerVarcharMapType() super(mapType(INTEGER, VARCHAR), Map.class, createTestBlock(mapType(INTEGER, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(INTEGER, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java index eddb22084a4b..e967ec08bff4 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalDayTimeType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.IntervalDayTimeType.INTERVAL_DAY_TIME; @@ -28,7 +28,7 @@ public TestIntervalDayTimeType() super(INTERVAL_DAY_TIME, SqlIntervalDayTime.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTERVAL_DAY_TIME.createBlockBuilder(null, 15); INTERVAL_DAY_TIME.writeLong(blockBuilder, 1111); @@ -42,7 +42,7 @@ public static Block createTestBlock() INTERVAL_DAY_TIME.writeLong(blockBuilder, 3333); INTERVAL_DAY_TIME.writeLong(blockBuilder, 3333); INTERVAL_DAY_TIME.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java index f16d0d577828..108ad544ead9 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIntervalYearMonthType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; @@ -28,7 +28,7 @@ public TestIntervalYearMonthType() super(INTERVAL_YEAR_MONTH, SqlIntervalYearMonth.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = INTERVAL_YEAR_MONTH.createBlockBuilder(null, 15); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 1111); @@ -42,7 +42,7 @@ public static Block createTestBlock() INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 3333); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 3333); INTERVAL_YEAR_MONTH.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java b/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java index a1f057f3c415..6c4a4c7d42ee 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestIpAddressType.java @@ -16,8 +16,8 @@ import com.google.common.net.InetAddresses; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static com.google.common.base.Preconditions.checkState; @@ -33,7 +33,7 @@ public TestIpAddressType() super(IPADDRESS, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = IPADDRESS.createBlockBuilder(null, 1); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8320")); @@ -46,7 +46,7 @@ public static Block createTestBlock() IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8327")); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8328")); IPADDRESS.writeSlice(blockBuilder, getSliceForAddress("2001:db8::ff00:42:8329")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestJsonType.java b/core/trino-main/src/test/java/io/trino/type/TestJsonType.java index 26bd09677460..20f14a17b824 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestJsonType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestJsonType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.type.JsonType.JSON; @@ -31,12 +31,12 @@ public TestJsonType() super(JSON, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = JSON.createBlockBuilder(null, 1); Slice slice = Slices.utf8Slice("{\"x\":1, \"y\":2}"); JSON.writeSlice(blockBuilder, slice); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java b/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java index ec01f82524b7..1dd5ebb9a895 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongDecimalType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; import io.trino.spi.type.Int128; @@ -36,7 +36,7 @@ public TestLongDecimalType() super(LONG_DECIMAL_TYPE, SqlDecimal.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = LONG_DECIMAL_TYPE.createBlockBuilder(null, 15); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("-12345678901234567890.1234567890")); @@ -50,7 +50,7 @@ public static Block createTestBlock() writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("32345678901234567890.1234567890")); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("32345678901234567890.1234567890")); writeBigDecimal(LONG_DECIMAL_TYPE, blockBuilder, new BigDecimal("42345678901234567890.1234567890")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java index d1aeb3cef4bb..6e194fafc0fc 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableList; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.SqlTimestamp; import io.trino.spi.type.Type.Range; @@ -36,7 +36,7 @@ public TestLongTimestampType() super(TIMESTAMP_NANOS, SqlTimestamp.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_NANOS.createBlockBuilder(null, 15); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(1111_123, 123_000)); @@ -50,7 +50,7 @@ public static Block createTestBlock() TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(3333_123, 123_000)); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(3333_123, 123_000)); TIMESTAMP_NANOS.writeObject(blockBuilder, new LongTimestamp(4444_123, 123_000)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java index 9016e994ada4..ce610e7fcf63 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestLongTimestampWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.SqlTimestampWithTimeZone; import io.trino.spi.type.Type; @@ -40,7 +40,7 @@ public TestLongTimestampWithTimeZoneType() super(TIMESTAMP_TZ_MICROS, SqlTimestampWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_TZ_MICROS.createBlockBuilder(null, 15); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(1111, 0, getTimeZoneKeyForOffset(0))); @@ -54,7 +54,7 @@ public static Block createTestBlock() TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(3333, 0, getTimeZoneKeyForOffset(8))); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(3333, 0, getTimeZoneKeyForOffset(9))); TIMESTAMP_TZ_MICROS.writeObject(blockBuilder, LongTimestampWithTimeZone.fromEpochMillisAndFraction(4444, 0, getTimeZoneKeyForOffset(10))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestRealType.java b/core/trino-main/src/test/java/io/trino/type/TestRealType.java index 1349b5697130..7fbfa28d67c1 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestRealType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestRealType.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.IntArrayBlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public TestRealType() super(REAL, Float.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = REAL.createBlockBuilder(null, 30); REAL.writeLong(blockBuilder, floatToRawIntBits(11.11F)); @@ -48,7 +49,7 @@ public static Block createTestBlock() REAL.writeLong(blockBuilder, floatToRawIntBits(33.33F)); REAL.writeLong(blockBuilder, floatToRawIntBits(33.33F)); REAL.writeLong(blockBuilder, floatToRawIntBits(44.44F)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java b/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java index 2807bdc50ed1..712e83d682fd 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortDecimalType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.DecimalType; import io.trino.spi.type.SqlDecimal; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestShortDecimalType() super(SHORT_DECIMAL_TYPE, SqlDecimal.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = SHORT_DECIMAL_TYPE.createBlockBuilder(null, 15); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, -1234); @@ -46,7 +46,7 @@ public static Block createTestBlock() SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 3321); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 3321); SHORT_DECIMAL_TYPE.writeLong(blockBuilder, 4321); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java index f5b113ab4d7f..a80089c14bca 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimestamp; import io.trino.spi.type.Type; import io.trino.spi.type.Type.Range; @@ -39,7 +39,7 @@ public TestShortTimestampType() super(TIMESTAMP_MILLIS, SqlTimestamp.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_MILLIS.createBlockBuilder(null, 15); TIMESTAMP_MILLIS.writeLong(blockBuilder, 1111_000); @@ -53,7 +53,7 @@ public static Block createTestBlock() TIMESTAMP_MILLIS.writeLong(blockBuilder, 3333_000); TIMESTAMP_MILLIS.writeLong(blockBuilder, 3333_000); TIMESTAMP_MILLIS.writeLong(blockBuilder, 4444_000); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java index 848a60321307..2fbf03b4963e 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestShortTimestampWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimestampWithTimeZone; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestShortTimestampWithTimeZoneType() super(TIMESTAMP_TZ_MILLIS, SqlTimestampWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIMESTAMP_TZ_MILLIS.createBlockBuilder(null, 15); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(1111, getTimeZoneKeyForOffset(0))); @@ -46,7 +46,7 @@ public static Block createTestBlock() TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(3333, getTimeZoneKeyForOffset(8))); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(3333, getTimeZoneKeyForOffset(9))); TIMESTAMP_TZ_MILLIS.writeLong(blockBuilder, packDateTimeWithZone(4444, getTimeZoneKeyForOffset(10))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java b/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java index 4051d8275b1b..a31cadf58b20 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSimpleRowType.java @@ -14,9 +14,9 @@ package io.trino.type; import com.google.common.collect.ImmutableList; -import io.trino.spi.block.Block; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.RowType; import org.junit.jupiter.api.Test; @@ -41,7 +41,7 @@ public TestSimpleRowType() super(TYPE, List.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { RowBlockBuilder blockBuilder = TYPE.createBlockBuilder(null, 3); @@ -60,7 +60,7 @@ private static Block createTestBlock() VARCHAR.writeSlice(fieldBuilders.get(1), utf8Slice("dog")); }); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java index 40f97a4426f2..7874f8defe9a 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestSmallintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(SMALLINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(SMALLINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(SMALLINT, 100, 200, 300)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } SMALLINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java b/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java index bfbf5f76d737..46a163aa9d2b 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestSmallintType() super(SMALLINT, Short.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = SMALLINT.createBlockBuilder(null, 15); SMALLINT.writeLong(blockBuilder, 1111); @@ -46,7 +46,7 @@ public static Block createTestBlock() SMALLINT.writeLong(blockBuilder, 3333); SMALLINT.writeLong(blockBuilder, 3333); SMALLINT.writeLong(blockBuilder, 4444); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java index 813611cbef70..15eead1a0eb6 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestSmallintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestSmallintVarcharMapType() super(mapType(SMALLINT, VARCHAR), Map.class, createTestBlock(mapType(SMALLINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(SMALLINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTimeType.java b/core/trino-main/src/test/java/io/trino/type/TestTimeType.java index 99bc510a2984..2454aa80d464 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTimeType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTimeType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTime; import org.junit.jupiter.api.Test; @@ -29,7 +29,7 @@ public TestTimeType() super(TIME_MILLIS, SqlTime.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIME_MILLIS.createBlockBuilder(null, 15); TIME_MILLIS.writeLong(blockBuilder, 1_111_000_000_000L); @@ -43,7 +43,7 @@ public static Block createTestBlock() TIME_MILLIS.writeLong(blockBuilder, 3_333_000_000_000L); TIME_MILLIS.writeLong(blockBuilder, 3_333_000_000_000L); TIME_MILLIS.writeLong(blockBuilder, 4_444_000_000_000L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java b/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java index 35657579b923..df18e4c47593 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTimeWithTimeZoneType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlTimeWithTimeZone; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestTimeWithTimeZoneType() super(TIME_TZ_MILLIS, SqlTimeWithTimeZone.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TIME_TZ_MILLIS.createBlockBuilder(null, 15); TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(1_111_000_000L, 0)); @@ -46,7 +46,7 @@ public static Block createTestBlock() TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(3_333_000_000L, 8)); TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(3_333_000_000L, 9)); TIME_TZ_MILLIS.writeLong(blockBuilder, packTimeWithTimeZone(4_444_000_000L, 10)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java index e78bb757d25d..327622dd2c28 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -34,14 +35,14 @@ public TestTinyintArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(TINYINT.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(TINYINT.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 1, 2, 3)); arrayType.writeObject(blockBuilder, arrayBlockOf(TINYINT, 100, 110, 127)); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override @@ -54,7 +55,7 @@ protected Object getGreaterValue(Object value) } TINYINT.writeLong(blockBuilder, 1L); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java index a4d3b0c75cab..c4987a648156 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintType.java @@ -13,8 +13,8 @@ */ package io.trino.type; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type.Range; import org.junit.jupiter.api.Test; @@ -32,7 +32,7 @@ public TestTinyintType() super(TINYINT, Byte.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = TINYINT.createBlockBuilder(null, 15); TINYINT.writeLong(blockBuilder, 111); @@ -46,7 +46,7 @@ public static Block createTestBlock() TINYINT.writeLong(blockBuilder, 33); TINYINT.writeLong(blockBuilder, 33); TINYINT.writeLong(blockBuilder, 44); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java index 4f34ce2ec4d9..522bc24c44d8 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestTinyintVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -36,13 +36,13 @@ public TestTinyintVarcharMapType() super(mapType(TINYINT, VARCHAR), Map.class, createTestBlock(mapType(TINYINT, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "hi"))); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "2", 2, "hello"))); mapType.writeObject(blockBuilder, sqlMapOf(TINYINT, VARCHAR, ImmutableMap.of(1, "123456789012345", 2, "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java b/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java index 36911a64fd90..4fb2e02eab98 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestUnboundedVarcharType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import org.junit.jupiter.api.Test; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -30,7 +30,7 @@ public TestUnboundedVarcharType() super(VARCHAR, String.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, 15); VARCHAR.writeString(blockBuilder, "apple"); @@ -44,7 +44,7 @@ private static Block createTestBlock() VARCHAR.writeString(blockBuilder, "cherry"); VARCHAR.writeString(blockBuilder, "cherry"); VARCHAR.writeString(blockBuilder, "date"); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java b/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java index ff22b3c1d43e..6a2a0ce364ac 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestUnknownType.java @@ -29,7 +29,7 @@ public TestUnknownType() .appendNull() .appendNull() .appendNull() - .build()); + .buildValueBlock()); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestUuidType.java b/core/trino-main/src/test/java/io/trino/type/TestUuidType.java index a83ef738ba16..c10478f35c53 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestUuidType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestUuidType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.TypeOperators; import org.junit.jupiter.api.Test; @@ -44,14 +44,14 @@ public TestUuidType() super(UUID, String.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = UUID.createBlockBuilder(null, 1); for (int i = 0; i < 10; i++) { String uuid = "6b5f5b65-67e4-43b0-8ee3-586cd49f58a" + i; UUID.writeSlice(blockBuilder, castFromVarcharToUuid(utf8Slice(uuid))); } - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java b/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java index c9468a5421db..4b3e2c92dcd2 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarbinaryType.java @@ -15,8 +15,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.SqlVarbinary; import org.junit.jupiter.api.Test; @@ -31,7 +31,7 @@ public TestVarbinaryType() super(VARBINARY, SqlVarbinary.class, createTestBlock()); } - public static Block createTestBlock() + public static ValueBlock createTestBlock() { BlockBuilder blockBuilder = VARBINARY.createBlockBuilder(null, 15); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("apple")); @@ -45,7 +45,7 @@ public static Block createTestBlock() VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("cherry")); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("cherry")); VARBINARY.writeSlice(blockBuilder, Slices.utf8Slice("date")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java b/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java index 2c403c4ec5f9..0015ab005c09 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarcharArrayType.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,13 +36,13 @@ public TestVarcharArrayType() super(TESTING_TYPE_MANAGER.getType(arrayType(VARCHAR.getTypeSignature())), List.class, createTestBlock(TESTING_TYPE_MANAGER.getType(arrayType(VARCHAR.getTypeSignature())))); } - public static Block createTestBlock(Type arrayType) + public static ValueBlock createTestBlock(Type arrayType) { BlockBuilder blockBuilder = arrayType.createBlockBuilder(null, 4); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "1", "2")); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "the", "quick", "brown", "fox")); arrayType.writeObject(blockBuilder, arrayBlockOf(VARCHAR, "one-two-three-four-five", "123456789012345", "the quick brown fox", "hello-world-hello-world-hello-world")); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java b/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java index dc8858afb405..42a666cda7e2 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java +++ b/core/trino-main/src/test/java/io/trino/type/TestVarcharVarcharMapType.java @@ -14,8 +14,8 @@ package io.trino.type; import com.google.common.collect.ImmutableMap; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -35,13 +35,13 @@ public TestVarcharVarcharMapType() super(mapType(VARCHAR, VARCHAR), Map.class, createTestBlock(mapType(VARCHAR, VARCHAR))); } - public static Block createTestBlock(Type mapType) + public static ValueBlock createTestBlock(Type mapType) { BlockBuilder blockBuilder = mapType.createBlockBuilder(null, 2); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("hi", "there"))); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("one", "1", "hello", "world"))); mapType.writeObject(blockBuilder, sqlMapOf(VARCHAR, VARCHAR, ImmutableMap.of("one-two-three-four-five", "123456789012345", "the quick brown fox", "hello-world-hello-world-hello-world"))); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java index 19fe8278990d..b8b86d57f99c 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java +++ b/core/trino-parser/src/main/java/io/trino/sql/QueryUtil.java @@ -25,7 +25,6 @@ import io.trino.sql.tree.Identifier; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.Node; -import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.Offset; import io.trino.sql.tree.OrderBy; import io.trino.sql.tree.QualifiedName; @@ -268,25 +267,6 @@ public static Query singleValueQuery(String columnName, boolean value) aliased(values, "t", ImmutableList.of(columnName))); } - // TODO pass column types - public static Query emptyQuery(List columns) - { - Select select = selectList(columns.stream() - .map(column -> new SingleColumn(new NullLiteral(), QueryUtil.identifier(column))) - .toArray(SelectItem[]::new)); - Optional where = Optional.of(FALSE_LITERAL); - return query(new QuerySpecification( - select, - Optional.empty(), - where, - Optional.empty(), - Optional.empty(), - ImmutableList.of(), - Optional.empty(), - Optional.empty(), - Optional.empty())); - } - public static Query query(QueryBody body) { return new Query( diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 7eb925b18135..ce43593d4d25 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -464,6 +464,489 @@ class io.trino.spi.block.VariableWidthBlock class io.trino.spi.block.VariableWidthBlock + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::copyPositions(int[], int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::copyRegion(int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::getLoadedBlock() + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getLoadedBlock() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::getRegion(int, int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractArrayBlock::getSingleValueBlock(int) @ io.trino.spi.block.ArrayBlock + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyRegion(int, int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::getRegion(int, int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyPositions(int[], int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyRegion(int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::copyWithAppendedNull() + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::getRegion(int, int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12Block::getSingleValueBlock(int) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12Block::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyRegion(int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::getRegion(int, int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyRegion(int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::getRegion(int, int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyRegion(int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::getRegion(int, int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::copyPositions(int[], int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::copyRegion(int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.MapBlock::copyWithAppendedNull() + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::getRegion(int, int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractMapBlock::getSingleValueBlock(int) @ io.trino.spi.block.MapBlock + method io.trino.spi.block.MapBlock io.trino.spi.block.MapBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::copyPositions(int[], int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::copyRegion(int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::copyWithAppendedNull() + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::getRegion(int, int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractRowBlock::getSingleValueBlock(int) @ io.trino.spi.block.RowBlock + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyPositions(int[], int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyRegion(int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::copyWithAppendedNull() + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::copyWithAppendedNull() + ADD YOUR EXPLANATION FOR THE NECESSITY OF THIS CHANGE + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::getRegion(int, int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlock::getSingleValueBlock(int) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyPositions(int[], int, int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::copyPositions(int[], int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyRegion(int, int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::copyRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::copyWithAppendedNull() + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::copyWithAppendedNull() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.VariableWidthBlock::getRegion(int, int) + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::getRegion(int, int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.AbstractVariableWidthBlock::getSingleValueBlock(int) @ io.trino.spi.block.VariableWidthBlock + method io.trino.spi.block.VariableWidthBlock io.trino.spi.block.VariableWidthBlock::getSingleValueBlock(int) + + + java.method.addedToInterface + method io.trino.spi.block.ValueBlock io.trino.spi.block.Block::getUnderlyingValueBlock() + + + java.method.addedToInterface + method int io.trino.spi.block.Block::getUnderlyingValuePosition(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RunLengthEncodedBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.RunLengthEncodedBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RunLengthEncodedBlock::getValue() + method io.trino.spi.block.ValueBlock io.trino.spi.block.RunLengthEncodedBlock::getValue() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Block::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.Block::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.DictionaryBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.DictionaryBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LazyBlock::getSingleValueBlock(int) + method io.trino.spi.block.ValueBlock io.trino.spi.block.LazyBlock::getSingleValueBlock(int) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.DictionaryBlock::getDictionary() + method io.trino.spi.block.ValueBlock io.trino.spi.block.DictionaryBlock::getDictionary() + + + java.method.addedToInterface + method io.trino.spi.block.ValueBlock io.trino.spi.block.BlockBuilder::buildValueBlock() + + + java.method.numberOfParametersChanged + method void io.trino.spi.type.AbstractType::<init>(io.trino.spi.type.TypeSignature, java.lang.Class<?>) + method void io.trino.spi.type.AbstractType::<init>(io.trino.spi.type.TypeSignature, java.lang.Class<?>, java.lang.Class<? extends io.trino.spi.block.ValueBlock>) + + + java.method.numberOfParametersChanged + method void io.trino.spi.type.TimeWithTimeZoneType::<init>(int, java.lang.Class<?>) + method void io.trino.spi.type.TimeWithTimeZoneType::<init>(int, java.lang.Class<?>, java.lang.Class<? extends io.trino.spi.block.ValueBlock>) + + + java.method.addedToInterface + method java.lang.Class<? extends io.trino.spi.block.ValueBlock> io.trino.spi.type.Type::getValueBlockType() + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.type.TypeUtils::writeNativeValue(io.trino.spi.type.Type, java.lang.Object) + method io.trino.spi.block.ValueBlock io.trino.spi.type.TypeUtils::writeNativeValue(io.trino.spi.type.Type, java.lang.Object) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlock::fromElementBlock(int, java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlock::fromElementBlock(int, java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.RowBlock::fromFieldBlocks(int, java.util.Optional<boolean[]>, io.trino.spi.block.Block[]) + method io.trino.spi.block.RowBlock io.trino.spi.block.RowBlock::fromFieldBlocks(int, java.util.Optional<boolean[]>, io.trino.spi.block.Block[]) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.type.MapType::createBlockFromKeyValue(java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block, io.trino.spi.block.Block) + method io.trino.spi.block.MapBlock io.trino.spi.type.MapType::createBlockFromKeyValue(java.util.Optional<boolean[]>, int[], io.trino.spi.block.Block, io.trino.spi.block.Block) + + + java.method.visibilityIncreased + method int[] io.trino.spi.block.DictionaryBlock::getRawIds() + method int[] io.trino.spi.block.DictionaryBlock::getRawIds() + package + public + + + java.method.visibilityIncreased + method int io.trino.spi.block.DictionaryBlock::getRawIdsOffset() + method int io.trino.spi.block.DictionaryBlock::getRawIdsOffset() + package + public + + + java.method.removed + method int io.trino.spi.block.Block::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.Block::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.Block::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.Block::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.Block::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.DictionaryBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.DictionaryBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.DictionaryBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.DictionaryBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.DictionaryBlock::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.LazyBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.LazyBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method int io.trino.spi.block.LazyBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.LazyBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.LazyBlock::hash(int, int, int) + + + java.method.removed + method int io.trino.spi.block.RunLengthEncodedBlock::bytesCompare(int, int, int, io.airlift.slice.Slice, int, int) + + + java.method.removed + method boolean io.trino.spi.block.RunLengthEncodedBlock::bytesEqual(int, int, io.airlift.slice.Slice, int, int) + ADD YOUR EXPLANATION FOR THE NECESSITY OF THIS CHANGE + + + java.method.removed + method int io.trino.spi.block.RunLengthEncodedBlock::compareTo(int, int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method boolean io.trino.spi.block.RunLengthEncodedBlock::equals(int, int, io.trino.spi.block.Block, int, int, int) + + + java.method.removed + method long io.trino.spi.block.RunLengthEncodedBlock::hash(int, int, int) + + + java.method.removed + method void io.trino.spi.block.Block::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.DictionaryBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.LazyBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.removed + method void io.trino.spi.block.RunLengthEncodedBlock::writeSliceTo(int, int, int, io.airlift.slice.SliceOutput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ArrayBlock io.trino.spi.block.ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ByteArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ByteArrayBlock io.trino.spi.block.ByteArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Fixed12BlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.Fixed12Block io.trino.spi.block.Fixed12BlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.Int128ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.Int128ArrayBlock io.trino.spi.block.Int128ArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.IntArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.IntArrayBlock io.trino.spi.block.IntArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.LongArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.LongArrayBlock io.trino.spi.block.LongArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.returnTypeChangedCovariantly + method io.trino.spi.block.Block io.trino.spi.block.ShortArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + method io.trino.spi.block.ShortArrayBlock io.trino.spi.block.ShortArrayBlockEncoding::readBlock(io.trino.spi.block.BlockEncodingSerde, io.airlift.slice.SliceInput) + + + java.method.nowStatic + method void io.trino.spi.type.AbstractIntType::checkValueValid(long) + method void io.trino.spi.type.AbstractIntType::checkValueValid(long) + diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java index 16d2eda205a0..a95abd2700bc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlock.java @@ -37,7 +37,7 @@ import static java.util.Objects.requireNonNull; public class ArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ArrayBlock.class); @@ -54,7 +54,7 @@ public class ArrayBlock * Create an array block directly from columnar nulls, values, and offsets into the values. * A null array must have no entries. */ - public static Block fromElementBlock(int positionCount, Optional valueIsNullOptional, int[] arrayOffset, Block values) + public static ArrayBlock fromElementBlock(int positionCount, Optional valueIsNullOptional, int[] arrayOffset, Block values) { boolean[] valueIsNull = valueIsNullOptional.orElse(null); validateConstructorArguments(0, positionCount, valueIsNull, arrayOffset, values); @@ -216,7 +216,7 @@ public boolean isLoaded() } @Override - public Block getLoadedBlock() + public ArrayBlock getLoadedBlock() { Block loadedValuesBlock = values.getLoadedBlock(); @@ -232,7 +232,7 @@ public Block getLoadedBlock() } @Override - public Block copyWithAppendedNull() + public ArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, getPositionCount()); int[] newOffsets = copyOffsetsAndAppendNull(offsets, arrayOffset, getPositionCount()); @@ -246,7 +246,7 @@ public Block copyWithAppendedNull() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -278,7 +278,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public ArrayBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -343,7 +343,7 @@ else if (rawElementBlock instanceof RunLengthEncodedBlock) { } @Override - public Block copyRegion(int position, int length) + public ArrayBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -369,15 +369,19 @@ public T getObject(int position, Class clazz) if (clazz != Block.class) { throw new IllegalArgumentException("clazz must be Block.class"); } - checkReadablePosition(this, position); + return clazz.cast(getArray(position)); + } + public Block getArray(int position) + { + checkReadablePosition(this, position); int startValueOffset = offsets[position + arrayOffset]; int endValueOffset = offsets[position + 1 + arrayOffset]; - return clazz.cast(values.getRegion(startValueOffset, endValueOffset - startValueOffset)); + return values.getRegion(startValueOffset, endValueOffset - startValueOffset); } @Override - public Block getSingleValueBlock(int position) + public ArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -421,6 +425,12 @@ public boolean isNull(int position) return valueIsNull != null && valueIsNull[position + arrayOffset]; } + @Override + public ArrayBlock getUnderlyingValueBlock() + { + return this; + } + public T apply(ArrayBlockFunction function, int position) { checkReadablePosition(this, position); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java index c23ae4c325a4..df28648003e7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockBuilder.java @@ -176,6 +176,15 @@ public Block build() if (!hasNonNullRow) { return nullRle(positionCount); } + return buildValueBlock(); + } + + @Override + public ValueBlock buildValueBlock() + { + if (currentEntryOpened) { + throw new IllegalStateException("Current entry must be closed before the block can be built"); + } return createArrayBlockInternal(0, positionCount, hasNullValue ? valueIsNull : null, offsets, values.build()); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java index bf3a3a9d26c2..ad01d4128132 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ArrayBlockEncoding.java @@ -50,11 +50,11 @@ public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceO for (int position = 0; position < positionCount + 1; position++) { sliceOutput.writeInt(offsets[offsetBase + position] - valuesStartOffset); } - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, arrayBlock); } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { Block values = blockEncodingSerde.readBlock(sliceInput); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java index 3281bb7e6727..e712cfdb3ef0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Block.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import java.util.Collections; import java.util.List; @@ -22,6 +21,7 @@ import java.util.function.ObjLongConsumer; import static io.trino.spi.block.BlockUtil.checkArrayRange; +import static io.trino.spi.block.DictionaryId.randomDictionaryId; public interface Block { @@ -74,14 +74,6 @@ default Slice getSlice(int position, int offset, int length) throw new UnsupportedOperationException(getClass().getName()); } - /** - * Writes a slice at {@code offset} in the value at {@code position} into the {@code output} slice output. - */ - default void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - throw new UnsupportedOperationException(getClass().getName()); - } - /** * Gets an object in the value at {@code position}. */ @@ -90,58 +82,6 @@ default T getObject(int position, Class clazz) throw new UnsupportedOperationException(getClass().getName()); } - /** - * Is the byte sequences at {@code offset} in the value at {@code position} equal - * to the byte sequence at {@code otherOffset} in {@code otherSlice}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Compares the byte sequences at {@code offset} in the value at {@code position} - * to the byte sequence at {@code otherOffset} in {@code otherSlice}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Is the byte sequences at {@code offset} in the value at {@code position} equal - * to the byte sequence at {@code otherOffset} in the value at {@code otherPosition} - * in {@code otherBlock}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Calculates the hash code the byte sequences at {@code offset} in the - * value at {@code position}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default long hash(int position, int offset, int length) - { - throw new UnsupportedOperationException(getClass().getName()); - } - - /** - * Compares the byte sequences at {@code offset} in the value at {@code position} - * to the byte sequence at {@code otherOffset} in the value at {@code otherPosition} - * in {@code otherBlock}. - * This method must be implemented if @{code getSlice} is implemented. - */ - default int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - throw new UnsupportedOperationException(getClass().getName()); - } - /** * Gets the value at the specified position as a single element block. The method * must copy the data into a new block. @@ -151,7 +91,7 @@ default int compareTo(int leftPosition, int leftOffset, int leftLength, Block ri * * @throws IllegalArgumentException if this position is not valid */ - Block getSingleValueBlock(int position); + ValueBlock getSingleValueBlock(int position); /** * Returns the number of positions in this block. @@ -243,7 +183,7 @@ default Block getPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); - return new DictionaryBlock(offset, length, this, positions); + return DictionaryBlock.createInternal(offset, length, this, positions, randomDictionaryId()); } /** @@ -334,4 +274,14 @@ default List getChildren() * i.e. not on in-progress block builders. */ Block copyWithAppendedNull(); + + /** + * Returns the underlying value block underlying this block. + */ + ValueBlock getUnderlyingValueBlock(); + + /** + * Returns the position in the underlying value block corresponding to the specified position in this block. + */ + int getUnderlyingValuePosition(int position); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java index 79f78dca5634..7d458991497e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/BlockBuilder.java @@ -43,9 +43,15 @@ public interface BlockBuilder /** * Builds the block. This method can be called multiple times. + * The return value may be a block such as RLE to allow for optimizations when all block values are the same. */ Block build(); + /** + * Builds a ValueBlock. This method can be called multiple times. + */ + ValueBlock buildValueBlock(); + /** * Creates a new block builder of the same type based on the current usage statistics of this block builder. */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java index da643b1aa370..44d45ca0a4c7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlock.java @@ -31,7 +31,7 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; public class ByteArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ByteArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Byte.BYTES + Byte.BYTES; @@ -128,10 +128,15 @@ public int getPositionCount() @Override public byte getByte(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getByte(position); + } + + public byte getByte(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -149,7 +154,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public ByteArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new ByteArrayBlock( @@ -160,7 +165,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ByteArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -181,7 +186,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public ByteArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -189,7 +194,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public ByteArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -210,7 +215,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public ByteArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); byte[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -218,6 +223,12 @@ public Block copyWithAppendedNull() return new ByteArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public ByteArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java index 0d5592813c5e..559ead304ead 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockBuilder.java @@ -13,8 +13,6 @@ */ package io.trino.spi.block; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import jakarta.annotation.Nullable; import java.util.Arrays; @@ -91,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public ByteArrayBlock buildValueBlock() + { return new ByteArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -150,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - Slice getValuesSlice() - { - return Slices.wrappedBuffer(values, 0, positionCount); - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java index 0fc86d4549d1..17f346f4e440 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ByteArrayBlockEncoding.java @@ -13,7 +13,6 @@ */ package io.trino.spi.block; -import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; @@ -37,20 +36,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + ByteArrayBlock byteArrayBlock = (ByteArrayBlock) block; + int positionCount = byteArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, byteArrayBlock); - if (!block.mayHaveNull()) { - sliceOutput.writeBytes(getValuesSlice(block)); + if (!byteArrayBlock.mayHaveNull()) { + sliceOutput.writeBytes(byteArrayBlock.getValuesSlice()); } else { byte[] valuesWithoutNull = new byte[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getByte(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = byteArrayBlock.getByte(i); + if (!byteArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -61,7 +61,7 @@ public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceO } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ByteArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); @@ -105,16 +105,4 @@ else if (packed != -1) { // At least one non-null } return new ByteArrayBlock(0, positionCount, valueIsNull, values); } - - private Slice getValuesSlice(Block block) - { - if (block instanceof ByteArrayBlock) { - return ((ByteArrayBlock) block).getValuesSlice(); - } - if (block instanceof ByteArrayBlockBuilder) { - return ((ByteArrayBlockBuilder) block).getValuesSlice(); - } - - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java index 44f51b1e9e30..d88feeb95c39 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ColumnarRow.java @@ -17,6 +17,7 @@ import java.util.List; +import static io.trino.spi.block.DictionaryId.randomDictionaryId; import static java.util.Objects.requireNonNull; public final class ColumnarRow @@ -103,11 +104,12 @@ private static ColumnarRow toColumnarRowFromDictionaryWithoutNulls(DictionaryBlo Block[] fields = new Block[columnarRow.getFieldCount()]; for (int i = 0; i < fields.length; i++) { // Reuse the dictionary ids array directly since no nulls are present - fields[i] = new DictionaryBlock( + fields[i] = DictionaryBlock.createInternal( dictionaryBlock.getRawIdsOffset(), dictionaryBlock.getPositionCount(), columnarRow.getField(i), - dictionaryBlock.getRawIds()); + dictionaryBlock.getRawIds(), + randomDictionaryId()); } return new ColumnarRow(dictionaryBlock.getPositionCount(), null, fields); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java index dd28bbe541c6..a110d1897d1e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/DictionaryBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import java.util.ArrayList; import java.util.Arrays; @@ -40,7 +39,7 @@ public class DictionaryBlock private static final int NULL_NOT_FOUND = -1; private final int positionCount; - private final Block dictionary; + private final ValueBlock dictionary; private final int idsOffset; private final int[] ids; private final long retainedSizeInBytes; @@ -54,7 +53,7 @@ public class DictionaryBlock public static Block create(int positionCount, Block dictionary, int[] ids) { - return createInternal(positionCount, dictionary, ids, randomDictionaryId()); + return createInternal(0, positionCount, dictionary, ids, randomDictionaryId()); } /** @@ -62,16 +61,16 @@ public static Block create(int positionCount, Block dictionary, int[] ids) */ public static Block createProjectedDictionaryBlock(int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) { - return createInternal(positionCount, dictionary, ids, dictionarySourceId); + return createInternal(0, positionCount, dictionary, ids, dictionarySourceId); } - private static Block createInternal(int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) + static Block createInternal(int idsOffset, int positionCount, Block dictionary, int[] ids, DictionaryId dictionarySourceId) { if (positionCount == 0) { return dictionary.copyRegion(0, 0); } if (positionCount == 1) { - return dictionary.getRegion(ids[0], 1); + return dictionary.getRegion(ids[idsOffset], 1); } // if dictionary is an RLE then this can just be a new RLE @@ -79,25 +78,19 @@ private static Block createInternal(int positionCount, Block dictionary, int[] i return RunLengthEncodedBlock.create(rle.getValue(), positionCount); } - // unwrap dictionary in dictionary - if (dictionary instanceof DictionaryBlock dictionaryBlock) { - int[] newIds = new int[positionCount]; - for (int position = 0; position < positionCount; position++) { - newIds[position] = dictionaryBlock.getId(ids[position]); - } - dictionary = dictionaryBlock.getDictionary(); - dictionarySourceId = randomDictionaryId(); - ids = newIds; + if (dictionary instanceof ValueBlock valueBlock) { + return new DictionaryBlock(idsOffset, positionCount, valueBlock, ids, false, false, dictionarySourceId); } - return new DictionaryBlock(0, positionCount, dictionary, ids, false, false, dictionarySourceId); - } - DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[] ids) - { - this(idsOffset, positionCount, dictionary, ids, false, false, randomDictionaryId()); + // unwrap dictionary in dictionary + int[] newIds = new int[positionCount]; + for (int position = 0; position < positionCount; position++) { + newIds[position] = dictionary.getUnderlyingValuePosition(ids[idsOffset + position]); + } + return new DictionaryBlock(0, positionCount, dictionary.getUnderlyingValueBlock(), newIds, false, false, randomDictionaryId()); } - private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[] ids, boolean dictionaryIsCompacted, boolean isSequentialIds, DictionaryId dictionarySourceId) + private DictionaryBlock(int idsOffset, int positionCount, ValueBlock dictionary, int[] ids, boolean dictionaryIsCompacted, boolean isSequentialIds, DictionaryId dictionarySourceId) { requireNonNull(dictionary, "dictionary is null"); requireNonNull(ids, "ids is null"); @@ -130,12 +123,12 @@ private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[ this.isSequentialIds = isSequentialIds; } - int[] getRawIds() + public int[] getRawIds() { return ids; } - int getRawIdsOffset() + public int getRawIdsOffset() { return idsOffset; } @@ -176,12 +169,6 @@ public Slice getSlice(int position, int offset, int length) return dictionary.getSlice(getId(position), offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - dictionary.writeSliceTo(getId(position), offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -189,37 +176,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - return dictionary.bytesEqual(getId(position), offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - return dictionary.bytesCompare(getId(position), offset, length, otherSlice, otherOffset, otherLength); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - return dictionary.equals(getId(position), offset, otherBlock, otherPosition, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - return dictionary.hash(getId(position), offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - return dictionary.compareTo(getId(leftPosition), leftOffset, leftLength, rightBlock, rightPosition, rightOffset, rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { return dictionary.getSingleValueBlock(getId(position)); } @@ -431,7 +388,7 @@ public Block copyPositions(int[] positions, int offset, int length) } newIds[i] = newId; } - Block compactDictionary = dictionary.copyPositions(positionsToCopy.elements(), 0, positionsToCopy.size()); + ValueBlock compactDictionary = dictionary.copyPositions(positionsToCopy.elements(), 0, positionsToCopy.size()); if (positionsToCopy.size() == length) { // discovered that all positions are unique, so return the unwrapped underlying dictionary directly return compactDictionary; @@ -534,7 +491,7 @@ public Block copyWithAppendedNull() { int desiredLength = idsOffset + positionCount + 1; int[] newIds = Arrays.copyOf(ids, desiredLength); - Block newDictionary = dictionary; + ValueBlock newDictionary = dictionary; int nullIndex = NULL_NOT_FOUND; @@ -569,29 +526,24 @@ public String toString() } @Override - public boolean isLoaded() + public final List getChildren() { - return dictionary.isLoaded(); + return singletonList(getDictionary()); } @Override - public Block getLoadedBlock() + public ValueBlock getUnderlyingValueBlock() { - Block loadedDictionary = dictionary.getLoadedBlock(); - - if (loadedDictionary == dictionary) { - return this; - } - return new DictionaryBlock(idsOffset, getPositionCount(), loadedDictionary, ids, false, false, randomDictionaryId()); + return dictionary; } @Override - public final List getChildren() + public int getUnderlyingValuePosition(int position) { - return singletonList(getDictionary()); + return getId(position); } - public Block getDictionary() + public ValueBlock getDictionary() { return dictionary; } @@ -675,7 +627,7 @@ public DictionaryBlock compact() newIds[i] = newId; } try { - Block compactDictionary = dictionary.copyPositions(dictionaryPositionsToCopy.elements(), 0, dictionaryPositionsToCopy.size()); + ValueBlock compactDictionary = dictionary.copyPositions(dictionaryPositionsToCopy.elements(), 0, dictionaryPositionsToCopy.size()); return new DictionaryBlock( 0, positionCount, @@ -736,7 +688,7 @@ public static List compactRelatedBlocks(List b } try { - Block compactDictionary = dictionaryBlock.getDictionary().copyPositions(dictionaryPositionsToCopy, 0, numberOfIndexes); + ValueBlock compactDictionary = dictionaryBlock.getDictionary().copyPositions(dictionaryPositionsToCopy, 0, numberOfIndexes); outputDictionaryBlocks.add(new DictionaryBlock( 0, positionCount, diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java index b716a5a6935a..92ef4cf88c30 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12Block.java @@ -29,7 +29,7 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; public class Fixed12Block - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(Fixed12Block.class); public static final int FIXED12_BYTES = Long.BYTES + Integer.BYTES; @@ -127,12 +127,11 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { // If needed, we can add support for offset 4 throw new IllegalArgumentException("offset must be 0"); } - return decodeFixed12First(values, position + positionOffset); + return getFixed12First(position); } @Override @@ -151,6 +150,17 @@ public int getInt(int position, int offset) throw new IllegalArgumentException("offset must be 0, 4, or 8"); } + public long getFixed12First(int position) + { + checkReadablePosition(this, position); + return decodeFixed12First(values, position + positionOffset); + } + + public int getFixed12Second(int position) + { + return decodeFixed12Second(values, position + positionOffset); + } + @Override public boolean mayHaveNull() { @@ -165,7 +175,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public Fixed12Block getSingleValueBlock(int position) { checkReadablePosition(this, position); int index = (position + positionOffset) * 3; @@ -177,7 +187,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public Fixed12Block copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -202,7 +212,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public Fixed12Block getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -210,7 +220,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public Fixed12Block copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -231,13 +241,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public Fixed12Block copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, positionOffset, positionCount); int[] newValues = ensureCapacity(values, (positionOffset + positionCount + 1) * 3); return new Fixed12Block(positionOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public Fixed12Block getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java index 1ad36b3fe132..f0d9e278510f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockBuilder.java @@ -90,6 +90,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public Fixed12Block buildValueBlock() + { return new Fixed12Block(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -149,9 +155,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - int[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java index e33239743cf7..131837f74c86 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Fixed12BlockEncoding.java @@ -33,30 +33,23 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + Fixed12Block fixed12Block = (Fixed12Block) block; + int positionCount = fixed12Block.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, fixed12Block); - if (!block.mayHaveNull()) { - if (block instanceof Fixed12Block valueBlock) { - sliceOutput.writeInts(valueBlock.getRawValues(), valueBlock.getPositionOffset() * 3, valueBlock.getPositionCount() * 3); - } - else if (block instanceof Fixed12BlockBuilder blockBuilder) { - sliceOutput.writeInts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount() * 3); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!fixed12Block.mayHaveNull()) { + sliceOutput.writeInts(fixed12Block.getRawValues(), fixed12Block.getPositionOffset() * 3, fixed12Block.getPositionCount() * 3); } else { int[] valuesWithoutNull = new int[positionCount * 3]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getInt(i, 0); - valuesWithoutNull[nonNullPositionCount + 1] = block.getInt(i, 4); - valuesWithoutNull[nonNullPositionCount + 2] = block.getInt(i, 8); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = fixed12Block.getInt(i, 0); + valuesWithoutNull[nonNullPositionCount + 1] = fixed12Block.getInt(i, 4); + valuesWithoutNull[nonNullPositionCount + 2] = fixed12Block.getInt(i, 8); + if (!fixed12Block.isNull(i)) { nonNullPositionCount += 3; } } @@ -67,7 +60,7 @@ else if (block instanceof Fixed12BlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public Fixed12Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java index 311fd731982f..641c23ac4d9c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlock.java @@ -13,6 +13,7 @@ */ package io.trino.spi.block; +import io.trino.spi.type.Int128; import jakarta.annotation.Nullable; import java.util.Optional; @@ -29,7 +30,7 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; public class Int128ArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(Int128ArrayBlock.class); public static final int INT128_BYTES = Long.BYTES + Long.BYTES; @@ -127,16 +128,34 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset == 0) { - return values[(position + positionOffset) * 2]; + return getInt128High(position); } if (offset == 8) { - return values[((position + positionOffset) * 2) + 1]; + return getInt128Low(position); } throw new IllegalArgumentException("offset must be 0 or 8"); } + public Int128 getInt128(int position) + { + checkReadablePosition(this, position); + int offset = (position + positionOffset) * 2; + return Int128.valueOf(values[offset], values[offset + 1]); + } + + public long getInt128High(int position) + { + checkReadablePosition(this, position); + return values[(position + positionOffset) * 2]; + } + + public long getInt128Low(int position) + { + checkReadablePosition(this, position); + return values[((position + positionOffset) * 2) + 1]; + } + @Override public boolean mayHaveNull() { @@ -151,7 +170,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public Int128ArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new Int128ArrayBlock( @@ -164,7 +183,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public Int128ArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -186,7 +205,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public Int128ArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -194,7 +213,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public Int128ArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -215,13 +234,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public Int128ArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, positionOffset, positionCount); long[] newValues = ensureCapacity(values, (positionOffset + positionCount + 1) * 2); return new Int128ArrayBlock(positionOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public Int128ArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java index a3b7dc78dff1..f22ae8951fea 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockBuilder.java @@ -90,6 +90,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public Int128ArrayBlock buildValueBlock() + { return new Int128ArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -149,9 +155,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - long[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java index 889d7814716c..78e8191202e5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/Int128ArrayBlockEncoding.java @@ -33,29 +33,22 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + Int128ArrayBlock int128ArrayBlock = (Int128ArrayBlock) block; + int positionCount = int128ArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, int128ArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof Int128ArrayBlock valueBlock) { - sliceOutput.writeLongs(valueBlock.getRawValues(), valueBlock.getPositionOffset() * 2, valueBlock.getPositionCount() * 2); - } - else if (block instanceof Int128ArrayBlockBuilder blockBuilder) { - sliceOutput.writeLongs(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount() * 2); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!int128ArrayBlock.mayHaveNull()) { + sliceOutput.writeLongs(int128ArrayBlock.getRawValues(), int128ArrayBlock.getPositionOffset() * 2, int128ArrayBlock.getPositionCount() * 2); } else { long[] valuesWithoutNull = new long[positionCount * 2]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getLong(i, 0); - valuesWithoutNull[nonNullPositionCount + 1] = block.getLong(i, 8); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = int128ArrayBlock.getInt128High(i); + valuesWithoutNull[nonNullPositionCount + 1] = int128ArrayBlock.getInt128Low(i); + if (!int128ArrayBlock.isNull(i)) { nonNullPositionCount += 2; } } @@ -66,7 +59,7 @@ else if (block instanceof Int128ArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public Int128ArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java index 2160d585d96c..2aa843d23a29 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlock.java @@ -30,7 +30,7 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; public class IntArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(IntArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Integer.BYTES + Byte.BYTES; @@ -127,10 +127,15 @@ public int getPositionCount() @Override public int getInt(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getInt(position); + } + + public int getInt(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -148,7 +153,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public IntArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new IntArrayBlock( @@ -159,7 +164,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public IntArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -180,7 +185,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public IntArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -188,7 +193,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public IntArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -209,7 +214,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public IntArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); int[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -217,6 +222,12 @@ public Block copyWithAppendedNull() return new IntArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public IntArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java index 52d8ae115b0c..bf124103418b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public IntArrayBlock buildValueBlock() + { return new IntArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - int[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java index ffcf3b87060c..408475020e9a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/IntArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + IntArrayBlock intArrayBlock = (IntArrayBlock) block; + int positionCount = intArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, intArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof IntArrayBlock valueBlock) { - sliceOutput.writeInts(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof IntArrayBlockBuilder blockBuilder) { - sliceOutput.writeInts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!intArrayBlock.mayHaveNull()) { + sliceOutput.writeInts(intArrayBlock.getRawValues(), intArrayBlock.getRawValuesOffset(), intArrayBlock.getPositionCount()); } else { int[] valuesWithoutNull = new int[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getInt(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = intArrayBlock.getInt(i); + if (!intArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 +60,7 @@ else if (block instanceof IntArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public IntArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java index dc6fb5f00ddb..bf6c6515903c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import jakarta.annotation.Nullable; import java.util.ArrayList; @@ -87,12 +86,6 @@ public Slice getSlice(int position, int offset, int length) return getBlock().getSlice(position, offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - getBlock().writeSliceTo(position, offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -100,56 +93,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - return getBlock().bytesEqual(position, offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - return getBlock().bytesCompare( - position, - offset, - length, - otherSlice, - otherOffset, - otherLength); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - return getBlock().equals( - position, - offset, - otherBlock, - otherPosition, - otherOffset, - length); - } - - @Override - public long hash(int position, int offset, int length) - { - return getBlock().hash(position, offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - return getBlock().compareTo( - leftPosition, - leftOffset, - leftLength, - rightBlock, - rightPosition, - rightOffset, - rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { return getBlock().getSingleValueBlock(position); } @@ -291,6 +235,18 @@ public Block getLoadedBlock() return lazyData.getFullyLoadedBlock(); } + @Override + public ValueBlock getUnderlyingValueBlock() + { + return getBlock().getUnderlyingValueBlock(); + } + + @Override + public int getUnderlyingValuePosition(int position) + { + return getBlock().getUnderlyingValuePosition(position); + } + public static void listenForLoads(Block block, Consumer listener) { requireNonNull(block, "block is null"); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java index a472cab833f3..99a3df02b65a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LazyBlockEncoding.java @@ -23,8 +23,6 @@ public class LazyBlockEncoding { public static final String NAME = "LAZY"; - public LazyBlockEncoding() {} - @Override public String getName() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java index abba78907863..e0b39437b01a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlock.java @@ -30,7 +30,7 @@ import static java.lang.Math.toIntExact; public class LongArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(LongArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Long.BYTES + Byte.BYTES; @@ -127,10 +127,15 @@ public int getPositionCount() @Override public long getLong(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getLong(position); + } + + public long getLong(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -194,7 +199,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public LongArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new LongArrayBlock( @@ -205,7 +210,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public LongArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -226,7 +231,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public LongArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -234,7 +239,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public LongArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -255,7 +260,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public LongArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); long[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); @@ -263,6 +268,12 @@ public Block copyWithAppendedNull() return new LongArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public LongArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java index eaa16a21057b..09a530971ac1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public LongArrayBlock buildValueBlock() + { return new LongArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - long[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java index 0d6ce7d14679..5167fca68087 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/LongArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + LongArrayBlock longArrayBlock = (LongArrayBlock) block; + int positionCount = longArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, longArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof LongArrayBlock valueBlock) { - sliceOutput.writeLongs(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof LongArrayBlockBuilder blockBuilder) { - sliceOutput.writeLongs(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!longArrayBlock.mayHaveNull()) { + sliceOutput.writeLongs(longArrayBlock.getRawValues(), longArrayBlock.getRawValuesOffset(), longArrayBlock.getPositionCount()); } else { long[] valuesWithoutNull = new long[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getLong(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = longArrayBlock.getLong(i); + if (!longArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 +60,7 @@ else if (block instanceof LongArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public LongArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java index da07f1521019..63d53285597d 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlock.java @@ -41,7 +41,7 @@ import static java.util.Objects.requireNonNull; public class MapBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(MapBlock.class); @@ -308,7 +308,7 @@ protected void ensureHashTableLoaded() } @Override - public Block copyWithAppendedNull() + public MapBlock copyWithAppendedNull() { boolean[] newMapIsNull = copyIsNullAndAppendNull(mapIsNull, startOffset, getPositionCount()); int[] newOffsets = copyOffsetsAndAppendNull(offsets, startOffset, getPositionCount()); @@ -347,7 +347,7 @@ public String getEncodingName() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public MapBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -407,7 +407,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public MapBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -500,7 +500,7 @@ public final long getPositionsSizeInBytes(boolean[] positions, int selectedMapPo } @Override - public Block copyRegion(int position, int length) + public MapBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -541,21 +541,25 @@ public T getObject(int position, Class clazz) if (clazz != SqlMap.class) { throw new IllegalArgumentException("clazz must be SqlMap.class"); } - checkReadablePosition(this, position); + return clazz.cast(getMap(position)); + } + public SqlMap getMap(int position) + { + checkReadablePosition(this, position); int startEntryOffset = getOffset(position); int endEntryOffset = getOffset(position + 1); - return clazz.cast(new SqlMap( + return new SqlMap( mapType, keyBlock, valueBlock, new SqlMap.HashTableSupplier(this), startEntryOffset, - (endEntryOffset - startEntryOffset))); + (endEntryOffset - startEntryOffset)); } @Override - public Block getSingleValueBlock(int position) + public MapBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -611,6 +615,12 @@ public boolean isNull(int position) return mapIsNull != null && mapIsNull[position + startOffset]; } + @Override + public MapBlock getUnderlyingValueBlock() + { + return this; + } + // only visible for testing public boolean isHashTablesPresent() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java index ba0f9e819e72..477ae48f4ce1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/MapBlockBuilder.java @@ -166,6 +166,12 @@ private void entryAdded(boolean isNull) @Override public Block build() + { + return buildValueBlock(); + } + + @Override + public MapBlock buildValueBlock() { if (currentEntryOpened) { throw new IllegalStateException("Current entry must be closed before the block can be built"); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java index 90efc72c22cb..29240dc835f3 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlock.java @@ -36,7 +36,7 @@ import static java.util.Objects.requireNonNull; public class RowBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(RowBlock.class); private final int numFields; @@ -55,7 +55,7 @@ public class RowBlock /** * Create a row block directly from columnar nulls and field blocks. */ - public static Block fromFieldBlocks(int positionCount, Optional rowIsNullOptional, Block[] fieldBlocks) + public static RowBlock fromFieldBlocks(int positionCount, Optional rowIsNullOptional, Block[] fieldBlocks) { boolean[] rowIsNull = rowIsNullOptional.orElse(null); int[] fieldBlockOffsets = null; @@ -256,7 +256,7 @@ public Block getLoadedBlock() } @Override - public Block copyWithAppendedNull() + public RowBlock copyWithAppendedNull() { boolean[] newRowIsNull = copyIsNullAndAppendNull(rowIsNull, startOffset, getPositionCount()); @@ -305,7 +305,7 @@ public String getEncodingName() } @Override - public Block copyPositions(int[] positions, int offset, int length) + public RowBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -354,7 +354,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int position, int length) + public RowBlock getRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -492,7 +492,7 @@ private long getSpecificPositionsSizeInBytes(boolean[] positions, int selectedRo } @Override - public Block copyRegion(int position, int length) + public RowBlock copyRegion(int position, int length) { int positionCount = getPositionCount(); checkValidRegion(positionCount, position, length); @@ -533,13 +533,17 @@ public T getObject(int position, Class clazz) if (clazz != SqlRow.class) { throw new IllegalArgumentException("clazz must be SqlRow.class"); } - checkReadablePosition(this, position); + return clazz.cast(getRow(position)); + } - return clazz.cast(new SqlRow(getFieldBlockOffset(position), fieldBlocks)); + public SqlRow getRow(int position) + { + checkReadablePosition(this, position); + return new SqlRow(getFieldBlockOffset(position), fieldBlocks); } @Override - public Block getSingleValueBlock(int position) + public RowBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); @@ -583,4 +587,10 @@ public boolean isNull(int position) } return rowIsNull[position + startOffset]; } + + @Override + public RowBlock getUnderlyingValueBlock() + { + return this; + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java index 77e2d70d5825..cbbbb65e869b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RowBlockBuilder.java @@ -164,6 +164,16 @@ public Block build() if (!hasNonNullRow) { return nullRle(positionCount); } + return buildValueBlock(); + } + + @Override + public RowBlock buildValueBlock() + { + if (currentEntryOpened) { + throw new IllegalStateException("Current entry must be closed before the block can be built"); + } + Block[] fieldBlocks = new Block[fieldBlockBuilders.length]; for (int i = 0; i < fieldBlockBuilders.length; i++) { fieldBlocks[i] = fieldBlockBuilders[i].build(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java index 996d35398c87..8b4620f44fe0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/RunLengthEncodedBlock.java @@ -14,7 +14,6 @@ package io.trino.spi.block; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import io.trino.spi.predicate.Utils; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -59,13 +58,25 @@ public static Block create(Block value, int positionCount) if (positionCount == 1) { return value; } - return new RunLengthEncodedBlock(value, positionCount); + + if (value instanceof ValueBlock valueBlock) { + return new RunLengthEncodedBlock(valueBlock, positionCount); + } + + // unwrap the value + ValueBlock valueBlock = value.getUnderlyingValueBlock(); + int valuePosition = value.getUnderlyingValuePosition(0); + if (valueBlock.getPositionCount() == 1 && valuePosition == 0) { + return new RunLengthEncodedBlock(valueBlock, positionCount); + } + + return new RunLengthEncodedBlock(valueBlock.getRegion(valuePosition, 1), positionCount); } - private final Block value; + private final ValueBlock value; private final int positionCount; - private RunLengthEncodedBlock(Block value, int positionCount) + private RunLengthEncodedBlock(ValueBlock value, int positionCount) { requireNonNull(value, "value is null"); if (positionCount < 0) { @@ -75,24 +86,7 @@ private RunLengthEncodedBlock(Block value, int positionCount) throw new IllegalArgumentException("positionCount must be at least 2"); } - // do not nest an RLE or Dictionary in an RLE - if (value instanceof RunLengthEncodedBlock block) { - this.value = block.getValue(); - } - else if (value instanceof DictionaryBlock block) { - Block dictionary = block.getDictionary(); - int id = block.getId(0); - if (dictionary.getPositionCount() == 1 && id == 0) { - this.value = dictionary; - } - else { - this.value = dictionary.getRegion(id, 1); - } - } - else { - this.value = value; - } - + this.value = value; this.positionCount = positionCount; } @@ -102,7 +96,7 @@ public final List getChildren() return singletonList(value); } - public Block getValue() + public ValueBlock getValue() { return value; } @@ -247,13 +241,6 @@ public Slice getSlice(int position, int offset, int length) return value.getSlice(0, offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - checkReadablePosition(this, position); - value.writeSliceTo(0, offset, length, output); - } - @Override public T getObject(int position, Class clazz) { @@ -262,42 +249,7 @@ public T getObject(int position, Class clazz) } @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - checkReadablePosition(this, position); - return value.bytesEqual(0, offset, otherSlice, otherOffset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) - { - checkReadablePosition(this, position); - return value.bytesCompare(0, offset, length, otherSlice, otherOffset, otherLength); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - checkReadablePosition(this, position); - return value.equals(0, offset, otherBlock, otherPosition, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - checkReadablePosition(this, position); - return value.hash(0, offset, length); - } - - @Override - public int compareTo(int leftPosition, int leftOffset, int leftLength, Block rightBlock, int rightPosition, int rightOffset, int rightLength) - { - checkReadablePosition(this, leftPosition); - return value.compareTo(0, leftOffset, leftLength, rightBlock, rightPosition, rightOffset, rightLength); - } - - @Override - public Block getSingleValueBlock(int position) + public ValueBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return value; @@ -323,7 +275,7 @@ public Block copyWithAppendedNull() return create(value, positionCount + 1); } - Block dictionary = value.copyWithAppendedNull(); + ValueBlock dictionary = value.copyWithAppendedNull(); int[] ids = new int[positionCount + 1]; ids[positionCount] = 1; return DictionaryBlock.create(ids.length, dictionary, ids); @@ -340,19 +292,14 @@ public String toString() } @Override - public boolean isLoaded() + public ValueBlock getUnderlyingValueBlock() { - return value.isLoaded(); + return value; } @Override - public Block getLoadedBlock() + public int getUnderlyingValuePosition(int position) { - Block loadedValueBlock = value.getLoadedBlock(); - - if (loadedValueBlock == value) { - return this; - } - return create(loadedValueBlock, positionCount); + return 0; } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java index a34ac450556a..3ff986ad69b4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlock.java @@ -29,7 +29,7 @@ import static io.trino.spi.block.BlockUtil.ensureCapacity; public class ShortArrayBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(ShortArrayBlock.class); public static final int SIZE_IN_BYTES_PER_POSITION = Short.BYTES + Byte.BYTES; @@ -126,10 +126,15 @@ public int getPositionCount() @Override public short getShort(int position, int offset) { - checkReadablePosition(this, position); if (offset != 0) { throw new IllegalArgumentException("offset must be zero"); } + return getShort(position); + } + + public short getShort(int position) + { + checkReadablePosition(this, position); return values[position + arrayOffset]; } @@ -147,7 +152,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public ShortArrayBlock getSingleValueBlock(int position) { checkReadablePosition(this, position); return new ShortArrayBlock( @@ -158,7 +163,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public ShortArrayBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); @@ -179,7 +184,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public ShortArrayBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -187,7 +192,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public ShortArrayBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -208,13 +213,19 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public ShortArrayBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); short[] newValues = ensureCapacity(values, arrayOffset + positionCount + 1); return new ShortArrayBlock(arrayOffset, positionCount + 1, newValueIsNull, newValues); } + @Override + public ShortArrayBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java index aa3db4bcf4b1..ee44b44b6dc2 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockBuilder.java @@ -89,6 +89,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positionCount); } + return buildValueBlock(); + } + + @Override + public ShortArrayBlock buildValueBlock() + { return new ShortArrayBlock(0, positionCount, hasNullValue ? valueIsNull : null, values); } @@ -148,9 +154,4 @@ public String toString() sb.append('}'); return sb.toString(); } - - short[] getRawValues() - { - return values; - } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java index 0aa79f278376..15813a428f74 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ShortArrayBlockEncoding.java @@ -35,28 +35,21 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - int positionCount = block.getPositionCount(); + ShortArrayBlock shortArrayBlock = (ShortArrayBlock) block; + int positionCount = shortArrayBlock.getPositionCount(); sliceOutput.appendInt(positionCount); - encodeNullsAsBits(sliceOutput, block); + encodeNullsAsBits(sliceOutput, shortArrayBlock); - if (!block.mayHaveNull()) { - if (block instanceof ShortArrayBlock valueBlock) { - sliceOutput.writeShorts(valueBlock.getRawValues(), valueBlock.getRawValuesOffset(), valueBlock.getPositionCount()); - } - else if (block instanceof ShortArrayBlockBuilder blockBuilder) { - sliceOutput.writeShorts(blockBuilder.getRawValues(), 0, blockBuilder.getPositionCount()); - } - else { - throw new IllegalArgumentException("Unexpected block type " + block.getClass().getSimpleName()); - } + if (!shortArrayBlock.mayHaveNull()) { + sliceOutput.writeShorts(shortArrayBlock.getRawValues(), shortArrayBlock.getRawValuesOffset(), shortArrayBlock.getPositionCount()); } else { short[] valuesWithoutNull = new short[positionCount]; int nonNullPositionCount = 0; for (int i = 0; i < positionCount; i++) { - valuesWithoutNull[nonNullPositionCount] = block.getShort(i, 0); - if (!block.isNull(i)) { + valuesWithoutNull[nonNullPositionCount] = shortArrayBlock.getShort(i); + if (!shortArrayBlock.isNull(i)) { nonNullPositionCount++; } } @@ -67,7 +60,7 @@ else if (block instanceof ShortArrayBlockBuilder blockBuilder) { } @Override - public Block readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) + public ShortArrayBlock readBlock(BlockEncodingSerde blockEncodingSerde, SliceInput sliceInput) { int positionCount = sliceInput.readInt(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java b/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java index 3a9335a05cbd..29f4bc1988c4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/SqlMap.java @@ -98,6 +98,26 @@ public Block getRawValueBlock() return rawValueBlock; } + public int getUnderlyingKeyPosition(int position) + { + return rawKeyBlock.getUnderlyingValuePosition(offset + position); + } + + public ValueBlock getUnderlyingKeyBlock() + { + return rawKeyBlock.getUnderlyingValueBlock(); + } + + public int getUnderlyingValuePosition(int position) + { + return rawValueBlock.getUnderlyingValuePosition(offset + position); + } + + public ValueBlock getUnderlyingValueBlock() + { + return rawValueBlock.getUnderlyingValueBlock(); + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java b/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java index 2acc019fe596..fcb7e2950e86 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/SqlRow.java @@ -82,6 +82,16 @@ public void retainedBytesForEachPart(ObjLongConsumer consumer) consumer.accept(this, INSTANCE_SIZE); } + public int getUnderlyingFieldPosition(int fieldIndex) + { + return fieldBlocks[fieldIndex].getUnderlyingValuePosition(rawIndex); + } + + public ValueBlock getUnderlyingFieldBlock(int fieldIndex) + { + return fieldBlocks[fieldIndex].getUnderlyingValueBlock(); + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java new file mode 100644 index 000000000000..f55f5567b20c --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/block/ValueBlock.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.block; + +public interface ValueBlock + extends Block +{ + @Override + ValueBlock copyPositions(int[] positions, int offset, int length); + + @Override + ValueBlock getRegion(int positionOffset, int length); + + @Override + ValueBlock copyRegion(int position, int length); + + @Override + ValueBlock copyWithAppendedNull(); + + @Override + default ValueBlock getUnderlyingValueBlock() + { + return this; + } + + @Override + default int getUnderlyingValuePosition(int position) + { + return position; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java index 9c2d53940554..3a828298091e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlock.java @@ -16,7 +16,6 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; -import io.airlift.slice.XxHash64; import jakarta.annotation.Nullable; import java.util.Optional; @@ -36,7 +35,7 @@ import static io.trino.spi.block.BlockUtil.copyOffsetsAndAppendNull; public class VariableWidthBlock - implements Block + implements ValueBlock { private static final int INSTANCE_SIZE = instanceSize(VariableWidthBlock.class); @@ -214,54 +213,12 @@ public Slice getSlice(int position, int offset, int length) return slice.slice(getPositionOffset(position) + offset, length); } - @Override - public void writeSliceTo(int position, int offset, int length, SliceOutput output) - { - checkReadablePosition(this, position); - output.writeBytes(slice, getPositionOffset(position) + offset, length); - } - - @Override - public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) - { - checkReadablePosition(this, position); - Slice rawSlice = slice; - if (getSliceLength(position) < length) { - return false; - } - return otherBlock.bytesEqual(otherPosition, otherOffset, rawSlice, getPositionOffset(position) + offset, length); - } - - @Override - public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) - { - checkReadablePosition(this, position); - return slice.equals(getPositionOffset(position) + offset, length, otherSlice, otherOffset, length); - } - - @Override - public long hash(int position, int offset, int length) - { - checkReadablePosition(this, position); - return XxHash64.hash(slice, getPositionOffset(position) + offset, length); - } - - @Override - public int compareTo(int position, int offset, int length, Block otherBlock, int otherPosition, int otherOffset, int otherLength) - { - checkReadablePosition(this, position); - Slice rawSlice = slice; - if (getSliceLength(position) < length) { - throw new IllegalArgumentException("Length longer than value length"); - } - return -otherBlock.bytesCompare(otherPosition, otherOffset, otherLength, rawSlice, getPositionOffset(position) + offset, length); - } - - @Override - public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) + public Slice getSlice(int position) { checkReadablePosition(this, position); - return slice.compareTo(getPositionOffset(position) + offset, length, otherSlice, otherOffset, otherLength); + int offset = offsets[position + arrayOffset]; + int length = offsets[position + 1 + arrayOffset] - offset; + return slice.slice(offset, length); } @Override @@ -278,7 +235,7 @@ public boolean isNull(int position) } @Override - public Block getSingleValueBlock(int position) + public VariableWidthBlock getSingleValueBlock(int position) { if (isNull(position)) { return new VariableWidthBlock(0, 1, EMPTY_SLICE, new int[] {0, 0}, new boolean[] {true}); @@ -293,7 +250,7 @@ public Block getSingleValueBlock(int position) } @Override - public Block copyPositions(int[] positions, int offset, int length) + public VariableWidthBlock copyPositions(int[] positions, int offset, int length) { checkArrayRange(positions, offset, length); if (length == 0) { @@ -337,7 +294,7 @@ public Block copyPositions(int[] positions, int offset, int length) } @Override - public Block getRegion(int positionOffset, int length) + public VariableWidthBlock getRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); @@ -345,7 +302,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public Block copyRegion(int positionOffset, int length) + public VariableWidthBlock copyRegion(int positionOffset, int length) { checkValidRegion(getPositionCount(), positionOffset, length); positionOffset += arrayOffset; @@ -367,7 +324,7 @@ public String getEncodingName() } @Override - public Block copyWithAppendedNull() + public VariableWidthBlock copyWithAppendedNull() { boolean[] newValueIsNull = copyIsNullAndAppendNull(valueIsNull, arrayOffset, positionCount); int[] newOffsets = copyOffsetsAndAppendNull(offsets, arrayOffset, positionCount); @@ -375,6 +332,12 @@ public Block copyWithAppendedNull() return new VariableWidthBlock(arrayOffset, positionCount + 1, slice, newOffsets, newValueIsNull); } + @Override + public VariableWidthBlock getUnderlyingValueBlock() + { + return this; + } + @Override public String toString() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java index 201f0029a4e7..6ec063828cf1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockBuilder.java @@ -190,6 +190,12 @@ public Block build() if (!hasNonNullValue) { return RunLengthEncodedBlock.create(NULL_VALUE_BLOCK, positions); } + return buildValueBlock(); + } + + @Override + public VariableWidthBlock buildValueBlock() + { return new VariableWidthBlock(0, positions, sliceOutput.slice(), offsets, hasNullValue ? valueIsNull : null); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java index 6218a859aec0..6e8af40a5b44 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java +++ b/core/trino-spi/src/main/java/io/trino/spi/block/VariableWidthBlockEncoding.java @@ -38,7 +38,6 @@ public String getName() @Override public void writeBlock(BlockEncodingSerde blockEncodingSerde, SliceOutput sliceOutput, Block block) { - // The down casts here are safe because it is the block itself the provides this encoding implementation. VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; int positionCount = variableWidthBlock.getPositionCount(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java index 935488f83f41..a494839c5253 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java @@ -115,6 +115,13 @@ public enum InvocationArgumentConvention * results are undefined. */ BLOCK_POSITION_NOT_NULL(false, 2), + /** + * Argument is passed a ValueBlock followed by the integer position in the block. + * The actual block parameter may be any subtype of ValueBlock, and the scalar function + * adapter will convert the parameter to ValueBlock. If the actual block position + * passed to the function argument is null, the results are undefined. + */ + VALUE_BLOCK_POSITION_NOT_NULL(false, 2), /** * Argument is always an object type. An SQL null will be passed a Java null. */ @@ -125,10 +132,16 @@ public enum InvocationArgumentConvention */ NULL_FLAG(true, 2), /** - * Argument is passed a Block followed by the integer position in the block. The + * Argument is passed a Block followed by the integer position in the block. The * sql value may be null. */ BLOCK_POSITION(true, 2), + /** + * Argument is passed a ValueBlock followed by the integer position in the block. + * The actual block parameter may be any subtype of ValueBlock, and the scalar function + * adapter will convert the parameter to ValueBlock. The sql value may be null. + */ + VALUE_BLOCK_POSITION(true, 2), /** * Argument is passed as a flat slice. The sql value may not be null. */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java index 1304c318894d..0575a08fe4f5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorMethodHandle.java @@ -15,6 +15,8 @@ import java.lang.invoke.MethodHandle; +import static java.util.Objects.requireNonNull; + public class OperatorMethodHandle { private final InvocationConvention callingConvention; @@ -22,8 +24,8 @@ public class OperatorMethodHandle public OperatorMethodHandle(InvocationConvention callingConvention, MethodHandle methodHandle) { - this.callingConvention = callingConvention; - this.methodHandle = methodHandle; + this.callingConvention = requireNonNull(callingConvention, "callingConvention is null"); + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); } public InvocationConvention getCallingConvention() diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index e1de94a16597..27b3e9385181 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -19,12 +19,14 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.util.List; import java.util.Objects; @@ -38,6 +40,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -53,18 +57,43 @@ import static java.lang.invoke.MethodHandles.insertArguments; import static java.lang.invoke.MethodHandles.lookup; import static java.lang.invoke.MethodHandles.permuteArguments; -import static java.lang.invoke.MethodHandles.publicLookup; import static java.lang.invoke.MethodHandles.throwException; import static java.lang.invoke.MethodType.methodType; import static java.util.Objects.requireNonNull; public final class ScalarFunctionAdapter { - private static final MethodHandle IS_NULL_METHOD = lookupIsNullMethod(); - private static final MethodHandle APPEND_NULL_METHOD = lookupAppendNullMethod(); + private static final MethodHandle OBJECT_IS_NULL_METHOD; + private static final MethodHandle APPEND_NULL_METHOD; + private static final MethodHandle BLOCK_IS_NULL_METHOD; + private static final MethodHandle IN_OUT_IS_NULL_METHOD; + private static final MethodHandle GET_UNDERLYING_VALUE_BLOCK_METHOD; + private static final MethodHandle GET_UNDERLYING_VALUE_POSITION_METHOD; + private static final MethodHandle NEW_NEVER_NULL_IS_NULL_EXCEPTION; // This is needed to convert flat arguments to stack types private static final TypeOperators READ_VALUE_TYPE_OPERATORS = new TypeOperators(); + static { + try { + MethodHandles.Lookup lookup = lookup(); + OBJECT_IS_NULL_METHOD = lookup.findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); + APPEND_NULL_METHOD = lookup.findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)) + .asType(methodType(void.class, BlockBuilder.class)); + BLOCK_IS_NULL_METHOD = lookup.findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); + IN_OUT_IS_NULL_METHOD = lookup.findVirtual(InOut.class, "isNull", methodType(boolean.class)); + + GET_UNDERLYING_VALUE_BLOCK_METHOD = lookup().findVirtual(Block.class, "getUnderlyingValueBlock", methodType(ValueBlock.class)); + GET_UNDERLYING_VALUE_POSITION_METHOD = lookup().findVirtual(Block.class, "getUnderlyingValuePosition", methodType(int.class, int.class)); + + NEW_NEVER_NULL_IS_NULL_EXCEPTION = lookup.findConstructor(TrinoException.class, methodType(void.class, ErrorCodeSupplier.class, String.class)) + .bindTo(StandardErrorCode.INVALID_FUNCTION_ARGUMENT) + .bindTo("A never null argument is null"); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + private ScalarFunctionAdapter() {} /** @@ -136,17 +165,26 @@ private static boolean canAdaptParameter( return switch (actualArgumentConvention) { case NEVER_NULL -> switch (expectedArgumentConvention) { - case BLOCK_POSITION_NOT_NULL, FLAT -> true; + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, FLAT -> true; case BOXED_NULLABLE, NULL_FLAG -> returnConvention != FAIL_ON_NULL; - case BLOCK_POSITION, IN_OUT -> true; // todo only support these if the return convention is nullable + case BLOCK_POSITION, VALUE_BLOCK_POSITION, IN_OUT -> true; // todo only support these if the return convention is nullable case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); // this is not needed as the case where actual and expected are the same is covered above, // but this means we will get a compile time error if a new convention is added in the future //noinspection DataFlowIssue case NEVER_NULL -> true; }; - case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && (returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL); - case BLOCK_POSITION -> expectedArgumentConvention == BLOCK_POSITION_NOT_NULL; + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> switch (expectedArgumentConvention) { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> true; + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL; + case NEVER_NULL, NULL_FLAG, BOXED_NULLABLE, FLAT, IN_OUT -> false; + case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); + }; + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> switch (expectedArgumentConvention) { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, BLOCK_POSITION, VALUE_BLOCK_POSITION -> true; + case NEVER_NULL, NULL_FLAG, BOXED_NULLABLE, FLAT, IN_OUT -> false; + case FUNCTION -> throw new IllegalStateException("Unexpected value: " + expectedArgumentConvention); + }; case BOXED_NULLABLE, NULL_FLAG -> true; case FLAT, IN_OUT -> false; case FUNCTION -> throw new IllegalArgumentException("Unsupported argument convention: " + actualArgumentConvention); @@ -263,6 +301,11 @@ private static MethodHandle adaptParameter( InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) { + // For value block, cast specialized parameter to ValueBlock + if (actualArgumentConvention == VALUE_BLOCK_POSITION || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL && methodHandle.type().parameterType(parameterIndex) != ValueBlock.class) { + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + } + if (actualArgumentConvention == expectedArgumentConvention) { return methodHandle; } @@ -324,7 +367,7 @@ private static MethodHandle adaptParameter( methodHandle = filterArguments( methodHandle, parameterIndex + 1, - explicitCastArguments(IS_NULL_METHOD, methodType(boolean.class, wrap(parameterType)))); + explicitCastArguments(OBJECT_IS_NULL_METHOD, methodType(boolean.class, wrap(parameterType)))); // 1. Duplicate the argument, so we have two copies of the value // Long, Long => Long @@ -359,88 +402,45 @@ private static MethodHandle adaptParameter( } if (expectedArgumentConvention == BLOCK_POSITION_NOT_NULL) { - if (actualArgumentConvention == BLOCK_POSITION) { - return methodHandle; + if (actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL || actualArgumentConvention == VALUE_BLOCK_POSITION) { + return adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); } - MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); - if (actualArgumentConvention == NEVER_NULL) { - return collectArguments(methodHandle, parameterIndex, getBlockValue); - } - if (actualArgumentConvention == BOXED_NULLABLE) { - MethodType targetType = getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType())); - return collectArguments(methodHandle, parameterIndex, explicitCastArguments(getBlockValue, targetType)); - } - if (actualArgumentConvention == NULL_FLAG) { - // actual method takes value and null flag, so change method handles to not have the flag and always pass false to the actual method - return collectArguments(insertArguments(methodHandle, parameterIndex + 1, false), parameterIndex, getBlockValue); + return adaptParameterToBlockPositionNotNull(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + if (expectedArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + if (actualArgumentConvention == VALUE_BLOCK_POSITION) { + return methodHandle; } + + methodHandle = adaptParameterToBlockPositionNotNull(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + return methodHandle; } // caller passes block and position which may contain a null if (expectedArgumentConvention == BLOCK_POSITION) { - MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); - - if (actualArgumentConvention == NEVER_NULL) { - if (returnConvention != FAIL_ON_NULL) { - // if caller sets the null flag, return null, otherwise invoke target - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); - - return guardWithTest( - isBlockPositionNull(methodHandle.type(), parameterIndex), - getNullShortCircuitResult(methodHandle, returnConvention), - methodHandle); - } - - MethodHandle adapter = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - throwTrinoNullArgumentException(getBlockValue.type()), - getBlockValue); - - return collectArguments(methodHandle, parameterIndex, adapter); + // convert ValueBlock argument to Block + if (actualArgumentConvention == VALUE_BLOCK_POSITION || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); + methodHandle = adaptValueBlockArgumentToBlock(methodHandle, parameterIndex); } - if (actualArgumentConvention == BOXED_NULLABLE) { - getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); - getBlockValue = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - empty(getBlockValue.type()), - getBlockValue); - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + if (actualArgumentConvention == VALUE_BLOCK_POSITION) { return methodHandle; } - if (actualArgumentConvention == NULL_FLAG) { - // long, boolean => long, Block, int - MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0); - methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull); - - // convert get block value to be null safe - getBlockValue = guardWithTest( - isBlockPositionNull(getBlockValue.type(), 0), - empty(getBlockValue.type()), - getBlockValue); - - // long, Block, int => Block, int, Block, int - methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + return adaptParameterToBlockPosition(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); + } - int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) - .map(i -> i <= parameterIndex + 1 ? i : i - 2) - .toArray(); - MethodType newType = methodHandle.type().dropParameterTypes(parameterIndex + 2, parameterIndex + 4); - methodHandle = permuteArguments(methodHandle, newType, reorder); - return methodHandle; - } - - if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL) { - if (returnConvention != FAIL_ON_NULL) { - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); - return guardWithTest( - isBlockPositionNull(methodHandle.type(), parameterIndex), - nullReturnValue, - methodHandle); - } + // caller passes value block and position which may contain a null + if (expectedArgumentConvention == VALUE_BLOCK_POSITION) { + if (actualArgumentConvention != BLOCK_POSITION) { + methodHandle = adaptParameterToBlockPosition(methodHandle, parameterIndex, argumentType, actualArgumentConvention, expectedArgumentConvention, returnConvention); } + methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + return methodHandle; } // caller will pass boolean true in the next argument for SQL null @@ -523,7 +523,118 @@ private static MethodHandle adaptParameter( } } - throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); + throw unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static MethodHandle adaptParameterToBlockPosition(MethodHandle methodHandle, int parameterIndex, Type argumentType, InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); + if (actualArgumentConvention == NEVER_NULL) { + if (returnConvention != FAIL_ON_NULL) { + // if caller sets the null flag, return null, otherwise invoke target + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + + return guardWithTest( + isBlockPositionNull(methodHandle.type(), parameterIndex), + getNullShortCircuitResult(methodHandle, returnConvention), + methodHandle); + } + + MethodHandle adapter = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + throwTrinoNullArgumentException(getBlockValue.type()), + getBlockValue); + + return collectArguments(methodHandle, parameterIndex, adapter); + } + + if (actualArgumentConvention == BOXED_NULLABLE) { + getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); + getBlockValue = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + empty(getBlockValue.type()), + getBlockValue); + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + return methodHandle; + } + + if (actualArgumentConvention == NULL_FLAG) { + // long, boolean => long, Block, int + MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0); + methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull); + + // convert get block value to be null safe + getBlockValue = guardWithTest( + isBlockPositionNull(getBlockValue.type(), 0), + empty(getBlockValue.type()), + getBlockValue); + + // long, Block, int => Block, int, Block, int + methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); + + int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) + .map(i -> i <= parameterIndex + 1 ? i : i - 2) + .toArray(); + MethodType newType = methodHandle.type().dropParameterTypes(parameterIndex + 2, parameterIndex + 4); + methodHandle = permuteArguments(methodHandle, newType, reorder); + return methodHandle; + } + + if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL || actualArgumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) { + if (returnConvention != FAIL_ON_NULL) { + MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); + return guardWithTest( + isBlockPositionNull(methodHandle.type(), parameterIndex), + nullReturnValue, + methodHandle); + } + } + throw unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static MethodHandle adaptParameterToBlockPositionNotNull(MethodHandle methodHandle, int parameterIndex, Type argumentType, InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + if (actualArgumentConvention == BLOCK_POSITION || actualArgumentConvention == BLOCK_POSITION_NOT_NULL) { + return methodHandle; + } + + MethodHandle getBlockValue = getBlockValue(argumentType, methodHandle.type().parameterType(parameterIndex)); + if (actualArgumentConvention == NEVER_NULL) { + return collectArguments(methodHandle, parameterIndex, getBlockValue); + } + if (actualArgumentConvention == BOXED_NULLABLE) { + MethodType targetType = getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType())); + return collectArguments(methodHandle, parameterIndex, explicitCastArguments(getBlockValue, targetType)); + } + if (actualArgumentConvention == NULL_FLAG) { + // actual method takes value and null flag, so change method handles to not have the flag and always pass false to the actual method + return collectArguments(insertArguments(methodHandle, parameterIndex + 1, false), parameterIndex, getBlockValue); + } + throw unsupportedArgumentAdaptation(actualArgumentConvention, expectedArgumentConvention, returnConvention); + } + + private static IllegalArgumentException unsupportedArgumentAdaptation(InvocationArgumentConvention actualArgumentConvention, InvocationArgumentConvention expectedArgumentConvention, InvocationReturnConvention returnConvention) + { + return new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); + } + + private static MethodHandle adaptValueBlockArgumentToBlock(MethodHandle methodHandle, int parameterIndex) + { + // someValueBlock, position => valueBlock, position + methodHandle = explicitCastArguments(methodHandle, methodHandle.type().changeParameterType(parameterIndex, ValueBlock.class)); + // valueBlock, position => block, position + methodHandle = collectArguments(methodHandle, parameterIndex, GET_UNDERLYING_VALUE_BLOCK_METHOD); + // block, position => block, block, position + methodHandle = collectArguments(methodHandle, parameterIndex + 1, GET_UNDERLYING_VALUE_POSITION_METHOD); + + // block, block, position => block, position + methodHandle = permuteArguments( + methodHandle, + methodHandle.type().dropParameterTypes(parameterIndex, parameterIndex + 1), + IntStream.range(0, methodHandle.type().parameterCount()) + .map(i -> i <= parameterIndex ? i : i - 1) + .toArray()); + return methodHandle; } private static MethodHandle getBlockValue(Type argumentType, Class expectedType) @@ -647,7 +758,7 @@ private static MethodHandle isTrueNullFlag(MethodType methodType, int index) private static MethodHandle isNullArgument(MethodType methodType, int index) { // Start with Objects.isNull(Object):boolean - MethodHandle isNull = IS_NULL_METHOD; + MethodHandle isNull = OBJECT_IS_NULL_METHOD; // Cast in incoming type: isNull(T):boolean isNull = explicitCastArguments(isNull, methodType(boolean.class, methodType.parameterType(index))); // Add extra argument to match the expected method type @@ -657,40 +768,15 @@ private static MethodHandle isNullArgument(MethodType methodType, int index) private static MethodHandle isBlockPositionNull(MethodType methodType, int index) { - // Start with Objects.isNull(Object):boolean - MethodHandle isNull; - try { - isNull = lookup().findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - // Add extra argument to match the expected method type - isNull = permuteArguments(isNull, methodType.changeReturnType(boolean.class), index, index + 1); - return isNull; + // Add extra argument to Block.isNull(int):boolean match the expected method type + MethodHandle blockIsNull = BLOCK_IS_NULL_METHOD.asType(BLOCK_IS_NULL_METHOD.type().changeParameterType(0, methodType.parameterType(index))); + return permuteArguments(blockIsNull, methodType.changeReturnType(boolean.class), index, index + 1); } private static MethodHandle isInOutNull(MethodType methodType, int index) { - MethodHandle isNull; - try { - isNull = lookup().findVirtual(InOut.class, "isNull", methodType(boolean.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - isNull = permuteArguments(isNull, methodType.changeReturnType(boolean.class), index); - return isNull; - } - - private static MethodHandle lookupIsNullMethod() - { - try { - return lookup().findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } + // Add extra argument to InOut.isNull(int):boolean match the expected method type + return permuteArguments(IN_OUT_IS_NULL_METHOD, methodType.changeReturnType(boolean.class), index); } private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, InvocationReturnConvention returnConvention) @@ -701,35 +787,12 @@ private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, return empty(methodHandle.type()); } - private static MethodHandle lookupAppendNullMethod() - { - try { - return lookup().findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)) - .asType(methodType(void.class, BlockBuilder.class)); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } - private static MethodHandle throwTrinoNullArgumentException(MethodType type) { - MethodHandle throwException = collectArguments(throwException(type.returnType(), TrinoException.class), 0, trinoNullArgumentException()); + MethodHandle throwException = collectArguments(throwException(type.returnType(), TrinoException.class), 0, NEW_NEVER_NULL_IS_NULL_EXCEPTION); return permuteArguments(throwException, type); } - private static MethodHandle trinoNullArgumentException() - { - try { - return publicLookup().findConstructor(TrinoException.class, methodType(void.class, ErrorCodeSupplier.class, String.class)) - .bindTo(StandardErrorCode.INVALID_FUNCTION_ARGUMENT) - .bindTo("A never null argument is null"); - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } - private static boolean isWrapperType(Class type) { return type != unwrap(type); diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java index 360a8fd67cea..d9adc24d9650 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractIntType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.IntArrayBlock; import io.trino.spi.block.IntArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -51,7 +53,7 @@ public abstract class AbstractIntType protected AbstractIntType(TypeSignature signature) { - super(signature, long.class); + super(signature, long.class, IntArrayBlock.class); } @Override @@ -86,13 +88,7 @@ public final long getLong(Block block, int position) public final int getInt(Block block, int position) { - return block.getInt(position, 0); - } - - @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); + return readInt((IntArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -107,7 +103,7 @@ public BlockBuilder writeInt(BlockBuilder blockBuilder, int value) return ((IntArrayBlockBuilder) blockBuilder).writeInt(value); } - protected void checkValueValid(long value) + protected static void checkValueValid(long value) { if (value > Integer.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_INT", value)); @@ -124,7 +120,7 @@ public final void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeInt(blockBuilder, block.getInt(position, 0)); + writeInt(blockBuilder, getInt(block, position)); } } @@ -161,6 +157,17 @@ public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) return new IntArrayBlockBuilder(null, positionCount); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition IntArrayBlock block, @BlockIndex int position) + { + return readInt(block, position); + } + + private static int readInt(IntArrayBlock block, int position) + { + return block.getInt(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java index 2caeb91e708b..030c2ce4fe92 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractLongType.java @@ -13,13 +13,15 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -49,7 +51,7 @@ public abstract class AbstractLongType public AbstractLongType(TypeSignature signature) { - super(signature, long.class); + super(signature, long.class, LongArrayBlock.class); } @Override @@ -79,13 +81,7 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper @Override public final long getLong(Block block, int position) { - return block.getLong(position, 0); - } - - @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -144,6 +140,12 @@ public static long hash(long value) return rotateLeft(value * 0xC2B2AE3D27D4EB4FL, 31) * 0x9E3779B185EBCA87L; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java index 2719f3a5ea0e..a7c0b9fe15e1 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import java.util.List; @@ -24,11 +25,13 @@ public abstract class AbstractType { private final TypeSignature signature; private final Class javaType; + private final Class valueBlockType; - protected AbstractType(TypeSignature signature, Class javaType) + protected AbstractType(TypeSignature signature, Class javaType, Class valueBlockType) { this.signature = signature; this.javaType = javaType; + this.valueBlockType = valueBlockType; } @Override @@ -49,6 +52,12 @@ public final Class getJavaType() return javaType; } + @Override + public Class getValueBlockType() + { + return valueBlockType; + } + @Override public List getTypeParameters() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java index 5d5d839792c5..be1cf7dd70ae 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/AbstractVariableWidthType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -54,7 +55,7 @@ public abstract class AbstractVariableWidthType protected AbstractVariableWidthType(TypeSignature signature, Class javaType) { - super(signature, javaType); + super(signature, javaType, VariableWidthBlock.class); } @Override @@ -72,7 +73,7 @@ public VariableWidthBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuil int expectedBytes = (int) min((long) expectedEntries * expectedBytesPerEntry, maxBlockSizeInBytes); return new VariableWidthBlockBuilder( blockBuilderStatus, - expectedBytesPerEntry == 0 ? expectedEntries : Math.min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), + expectedBytesPerEntry == 0 ? expectedEntries : min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), expectedBytes); } @@ -89,7 +90,12 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((VariableWidthBlockBuilder) blockBuilder).buildEntry(valueBuilder -> block.writeSliceTo(position, 0, block.getSliceLength(position), valueBuilder)); + VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + position = block.getUnderlyingValuePosition(position); + Slice slice = variableWidthBlock.getRawSlice(); + int offset = variableWidthBlock.getRawSliceOffset(position); + int length = variableWidthBlock.getSliceLength(position); + ((VariableWidthBlockBuilder) blockBuilder).writeEntry(slice, offset, length); } } @@ -190,21 +196,24 @@ private static void writeFlatFromStack( @ScalarOperator(READ_VALUE) private static void writeFlatFromBlock( - @BlockPosition Block block, + @BlockPosition VariableWidthBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) { + Slice rawSlice = block.getRawSlice(); + int rawSliceOffset = block.getRawSliceOffset(position); int length = block.getSliceLength(position); + INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, length); if (length <= 12) { - block.writeSliceTo(position, 0, length, wrappedBuffer(fixedSizeSlice, fixedSizeOffset + Integer.BYTES, length).getOutput()); + rawSlice.getBytes(rawSliceOffset, fixedSizeSlice, fixedSizeOffset + Integer.BYTES, length); } else { INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset + Integer.BYTES + Long.BYTES, variableSizeOffset); - block.writeSliceTo(position, 0, length, wrappedBuffer(variableSizeSlice, variableSizeOffset, length).getOutput()); + rawSlice.getBytes(rawSliceOffset, variableSizeSlice, variableSizeOffset, length); } } } @@ -218,31 +227,33 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - if (leftLength != rightLength) { - return false; - } - return leftBlock.equals(leftPosition, 0, rightBlock, rightPosition, 0, leftLength); + + return leftRawSlice.equals(leftRawSliceOffset, leftLength, rightRawSlice, rightRawSliceOffset, rightLength); } @ScalarOperator(EQUAL) - private static boolean equalOperator(Slice left, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(Slice left, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { return equalOperator(rightBlock, rightPosition, left); } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, Slice right) + private static boolean equalOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, Slice right) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); - int rightLength = right.length(); - if (leftLength != rightLength) { - return false; - } - return leftBlock.bytesEqual(leftPosition, 0, right, 0, leftLength); + + return leftRawSlice.equals(leftRawSliceOffset, leftLength, right, 0, right.length()); } @ScalarOperator(EQUAL) @@ -283,7 +294,7 @@ private static boolean equalOperator( @ScalarOperator(EQUAL) private static boolean equalOperator( - @BlockPosition Block leftBlock, + @BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @FlatFixed byte[] rightFixedSizeSlice, @FlatFixedOffset int rightFixedSizeOffset, @@ -302,19 +313,24 @@ private static boolean equalOperator( @FlatFixed byte[] leftFixedSizeSlice, @FlatFixedOffset int leftFixedSizeOffset, @FlatVariableWidth byte[] leftVariableSizeSlice, - @BlockPosition Block rightBlock, + @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { int leftLength = (int) INT_HANDLE.get(leftFixedSizeSlice, leftFixedSizeOffset); - if (rightBlock.isNull(rightPosition) || leftLength != rightBlock.getSliceLength(rightPosition)) { + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); + int rightLength = rightBlock.getSliceLength(rightPosition); + + if (leftLength != rightLength) { return false; } if (leftLength <= 12) { - return rightBlock.bytesEqual(rightPosition, 0, wrappedBuffer(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES, leftLength), 0, leftLength); + return rightRawSlice.equals(rightRawSliceOffset, rightLength, wrappedBuffer(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES, leftLength), 0, leftLength); } else { int leftVariableSizeOffset = (int) INT_HANDLE.get(leftFixedSizeSlice, leftFixedSizeOffset + Integer.BYTES + Long.BYTES); - return rightBlock.bytesEqual(rightPosition, 0, wrappedBuffer(leftVariableSizeSlice, leftVariableSizeOffset, leftLength), 0, leftLength); + return rightRawSlice.equals(rightRawSliceOffset, rightLength, wrappedBuffer(leftVariableSizeSlice, leftVariableSizeOffset, leftLength), 0, leftLength); } } @@ -325,9 +341,9 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition VariableWidthBlock block, @BlockIndex int position) { - return block.hash(position, 0, block.getSliceLength(position)); + return XxHash64.hash(block.getRawSlice(), block.getRawSliceOffset(position), block.getSliceLength(position)); } @ScalarOperator(XX_HASH_64) @@ -354,25 +370,37 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); + + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - return leftBlock.compareTo(leftPosition, 0, leftLength, rightBlock, rightPosition, 0, rightLength); + + return leftRawSlice.compareTo(leftRawSliceOffset, leftLength, rightRawSlice, rightRawSliceOffset, rightLength); } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, Slice right) + private static long comparisonOperator(@BlockPosition VariableWidthBlock leftBlock, @BlockIndex int leftPosition, Slice right) { + Slice leftRawSlice = leftBlock.getRawSlice(); + int leftRawSliceOffset = leftBlock.getRawSliceOffset(leftPosition); int leftLength = leftBlock.getSliceLength(leftPosition); - return leftBlock.bytesCompare(leftPosition, 0, leftLength, right, 0, right.length()); + + return leftRawSlice.compareTo(leftRawSliceOffset, leftLength, right, 0, right.length()); } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(Slice left, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(Slice left, @BlockPosition VariableWidthBlock rightBlock, @BlockIndex int rightPosition) { + Slice rightRawSlice = rightBlock.getRawSlice(); + int rightRawSliceOffset = rightBlock.getRawSliceOffset(rightPosition); int rightLength = rightBlock.getSliceLength(rightPosition); - return -rightBlock.bytesCompare(rightPosition, 0, rightLength, left, 0, left.length()); + + return left.compareTo(0, left.length(), rightRawSlice, rightRawSliceOffset, rightLength); } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java index ba02ad691d89..5b904f16943c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java @@ -18,6 +18,9 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorMethodHandle; @@ -33,12 +36,11 @@ import java.util.List; import java.util.function.BiFunction; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -97,13 +99,13 @@ public class ArrayType private final Type elementType; - // this field is used in double checked locking + // this field is used in double-checked locking @SuppressWarnings("FieldAccessedSynchronizedAndUnsynchronized") private volatile TypeOperatorDeclaration operatorDeclaration; public ArrayType(Type elementType) { - super(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), Block.class); + super(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), Block.class, ArrayBlock.class); this.elementType = requireNonNull(elementType, "elementType is null"); } @@ -139,7 +141,7 @@ private static List getReadValueOperatorMethodHandles(Type MethodHandle readFlat = insertArguments(READ_FLAT, 0, elementType, elementReadOperator, elementType.getFlatFixedSize()); MethodHandle readFlatToBlock = insertArguments(READ_FLAT_TO_BLOCK, 0, elementReadOperator, elementType.getFlatFixedSize()); - MethodHandle elementWriteOperator = typeOperators.getReadValueOperator(elementType, simpleConvention(FLAT_RETURN, BLOCK_POSITION)); + MethodHandle elementWriteOperator = typeOperators.getReadValueOperator(elementType, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL)); MethodHandle writeFlatToBlock = insertArguments(WRITE_FLAT, 0, elementType, elementWriteOperator, elementType.getFlatFixedSize(), elementType.isFlatVariableWidth()); return List.of( new OperatorMethodHandle(READ_FLAT_CONVENTION, readFlat), @@ -152,7 +154,7 @@ private static List getEqualOperatorMethodHandles(TypeOper if (!elementType.isComparable()) { return emptyList(); } - MethodHandle equalOperator = typeOperators.getEqualOperator(elementType, simpleConvention(NULLABLE_RETURN, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle equalOperator = typeOperators.getEqualOperator(elementType, simpleConvention(NULLABLE_RETURN, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(EQUAL_CONVENTION, EQUAL.bindTo(equalOperator))); } @@ -161,7 +163,7 @@ private static List getHashCodeOperatorMethodHandles(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementHashCodeOperator = typeOperators.getHashCodeOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementHashCodeOperator = typeOperators.getHashCodeOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(HASH_CODE_CONVENTION, HASH_CODE.bindTo(elementHashCodeOperator))); } @@ -170,7 +172,7 @@ private static List getXxHash64OperatorMethodHandles(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementHashCodeOperator = typeOperators.getXxHash64Operator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementHashCodeOperator = typeOperators.getXxHash64Operator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(HASH_CODE_CONVENTION, HASH_CODE.bindTo(elementHashCodeOperator))); } @@ -179,7 +181,7 @@ private static List getDistinctFromOperatorInvokers(TypeOp if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementDistinctFromOperator = typeOperators.getDistinctFromOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION)); + MethodHandle elementDistinctFromOperator = typeOperators.getDistinctFromOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(DISTINCT_FROM_CONVENTION, DISTINCT_FROM.bindTo(elementDistinctFromOperator))); } @@ -188,7 +190,7 @@ private static List getIndeterminateOperatorInvokers(TypeO if (!elementType.isComparable()) { return emptyList(); } - MethodHandle elementIndeterminateOperator = typeOperators.getIndeterminateOperator(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementIndeterminateOperator = typeOperators.getIndeterminateOperator(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(INDETERMINATE_CONVENTION, INDETERMINATE.bindTo(elementIndeterminateOperator))); } @@ -197,7 +199,7 @@ private static List getComparisonOperatorInvokers(BiFuncti if (!elementType.isOrderable()) { return emptyList(); } - MethodHandle elementComparisonOperator = comparisonOperatorFactory.apply(elementType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL)); + MethodHandle elementComparisonOperator = comparisonOperatorFactory.apply(elementType, simpleConvention(FAIL_ON_NULL, VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL)); return singletonList(new OperatorMethodHandle(COMPARISON_CONVENTION, COMPARISON.bindTo(elementComparisonOperator))); } @@ -228,7 +230,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block instanceof ArrayBlock) { return ((ArrayBlock) block).apply((valuesBlock, start, length) -> arrayBlockToObjectValues(session, valuesBlock, start, length), position); } - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = getObject(block, position); return arrayBlockToObjectValues(session, arrayBlock, 0, arrayBlock.getPositionCount()); } @@ -257,7 +259,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public Block getObject(Block block, int position) { - return block.getObject(position, Block.class); + return read((ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -386,6 +388,11 @@ public String getDisplayName() return ARRAY + "(" + elementType.getDisplayName() + ")"; } + private static Block read(ArrayBlock block, int position) + { + return block.getArray(position); + } + private static Block readFlat( Type elementType, MethodHandle elementReadFlat, @@ -457,34 +464,59 @@ private static void writeFlat( private static void writeFlatElements(Type elementType, MethodHandle elementWriteFlat, int elementFixedSize, boolean elementVariableWidth, Block array, byte[] slice, int offset) throws Throwable { + array = array.getLoadedBlock(); + int positionCount = array.getPositionCount(); // variable width data starts after fixed width data // there is one extra byte per position for the null flag int writeVariableWidthOffset = offset + positionCount * (1 + elementFixedSize); - for (int index = 0; index < positionCount; index++) { - if (array.isNull(index)) { - slice[offset] = 1; - offset++; + if (array instanceof ValueBlock valuesBlock) { + for (int index = 0; index < positionCount; index++) { + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, valuesBlock, index, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; } - else { - // skip null byte - offset++; + } + else if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + for (int index = 0; index < positionCount; index++) { + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, valuesBlock, 0, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; + } + } + else if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + for (int position = 0; position < positionCount; position++) { + int index = dictionaryBlock.getId(position); + writeVariableWidthOffset = writeFlatElement(elementType, elementWriteFlat, elementVariableWidth, valuesBlock, index, slice, offset, writeVariableWidthOffset); + offset += 1 + elementFixedSize; + } + } + else { + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); + } + } - int elementVariableSize = 0; - if (elementVariableWidth) { - elementVariableSize = elementType.getFlatVariableWidthSize(array, index); - } - elementWriteFlat.invokeExact( - array, - index, - slice, - offset, - slice, - writeVariableWidthOffset); - writeVariableWidthOffset += elementVariableSize; + private static int writeFlatElement(Type elementType, MethodHandle elementWriteFlat, boolean elementVariableWidth, ValueBlock array, int index, byte[] slice, int offset, int writeVariableWidthOffset) + throws Throwable + { + if (array.isNull(index)) { + slice[offset] = 1; + } + else { + int elementVariableSize = 0; + if (elementVariableWidth) { + elementVariableSize = elementType.getFlatVariableWidthSize(array, index); } - offset += elementFixedSize; + elementWriteFlat.invokeExact( + array, + index, + slice, + offset + 1, // skip null byte + slice, + writeVariableWidthOffset); + writeVariableWidthOffset += elementVariableSize; } + return writeVariableWidthOffset; } private static Boolean equalOperator(MethodHandle equalOperator, Block leftArray, Block rightArray) @@ -494,13 +526,21 @@ private static Boolean equalOperator(MethodHandle equalOperator, Block leftArray return false; } + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + boolean unknown = false; for (int position = 0; position < leftArray.getPositionCount(); position++) { - if (leftArray.isNull(position) || rightArray.isNull(position)) { + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + if (leftValues.isNull(leftIndex) || rightValues.isNull(rightIndex)) { unknown = true; continue; } - Boolean result = (Boolean) equalOperator.invokeExact(leftArray, position, rightArray, position); + Boolean result = (Boolean) equalOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result == null) { unknown = true; } @@ -515,15 +555,43 @@ else if (!result) { return true; } - private static long hashOperator(MethodHandle hashOperator, Block block) + private static long hashOperator(MethodHandle hashOperator, Block array) throws Throwable { - long hash = 0; - for (int position = 0; position < block.getPositionCount(); position++) { - long elementHash = block.isNull(position) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(block, position); - hash = 31 * hash + elementHash; + array = array.getLoadedBlock(); + + if (array instanceof ValueBlock valuesBlock) { + long hash = 0; + for (int index = 0; index < valuesBlock.getPositionCount(); index++) { + long elementHash = valuesBlock.isNull(index) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, index); + hash = 31 * hash + elementHash; + } + return hash; + } + + if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + long elementHash = valuesBlock.isNull(0) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, 0); + + long hash = 0; + for (int position = 0; position < valuesBlock.getPositionCount(); position++) { + hash = 31 * hash + elementHash; + } + return hash; } - return hash; + + if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + long hash = 0; + for (int position = 0; position < valuesBlock.getPositionCount(); position++) { + int index = dictionaryBlock.getId(position); + long elementHash = valuesBlock.isNull(position) ? NULL_HASH_CODE : (long) hashOperator.invokeExact(valuesBlock, index); + hash = 31 * hash + elementHash; + } + return hash; + } + + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); } private static boolean distinctFromOperator(MethodHandle distinctFromOperator, Block leftArray, Block rightArray) @@ -539,8 +607,26 @@ private static boolean distinctFromOperator(MethodHandle distinctFromOperator, B return true; } + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + for (int position = 0; position < leftArray.getPositionCount(); position++) { - boolean result = (boolean) distinctFromOperator.invokeExact(leftArray, position, rightArray, position); + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + + boolean leftValueIsNull = leftValues.isNull(leftIndex); + boolean rightValueIsNull = rightValues.isNull(rightIndex); + if (leftValueIsNull != rightValueIsNull) { + return true; + } + if (leftValueIsNull) { + continue; + } + + boolean result = (boolean) distinctFromOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result) { return true; } @@ -549,33 +635,73 @@ private static boolean distinctFromOperator(MethodHandle distinctFromOperator, B return false; } - private static boolean indeterminateOperator(MethodHandle elementIndeterminateFunction, Block block, boolean isNull) + private static boolean indeterminateOperator(MethodHandle elementIndeterminateFunction, Block array, boolean isNull) throws Throwable { if (isNull) { return true; } - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { + array = array.getLoadedBlock(); + + if (array instanceof ValueBlock valuesBlock) { + for (int index = 0; index < valuesBlock.getPositionCount(); index++) { + if (valuesBlock.isNull(index)) { + return true; + } + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, index)) { + return true; + } + } + return false; + } + + if (array instanceof RunLengthEncodedBlock rleBlock) { + ValueBlock valuesBlock = rleBlock.getValue(); + if (valuesBlock.isNull(0)) { return true; } - if ((boolean) elementIndeterminateFunction.invoke(block, position)) { + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, 0)) { return true; } + return false; } - return false; + + if (array instanceof DictionaryBlock dictionaryBlock) { + ValueBlock valuesBlock = dictionaryBlock.getDictionary(); + for (int position = 0; position < valuesBlock.getPositionCount(); position++) { + int index = dictionaryBlock.getId(position); + if (valuesBlock.isNull(index)) { + return true; + } + if ((boolean) elementIndeterminateFunction.invoke(valuesBlock, index)) { + return true; + } + } + return false; + } + + throw new IllegalArgumentException("Unsupported block type: " + array.getClass().getName()); } private static long comparisonOperator(MethodHandle comparisonOperator, Block leftArray, Block rightArray) throws Throwable { + leftArray = leftArray.getLoadedBlock(); + rightArray = rightArray.getLoadedBlock(); + + ValueBlock leftValues = leftArray.getUnderlyingValueBlock(); + ValueBlock rightValues = rightArray.getUnderlyingValueBlock(); + int len = Math.min(leftArray.getPositionCount(), rightArray.getPositionCount()); for (int position = 0; position < len; position++) { checkElementNotNull(leftArray.isNull(position), ARRAY_NULL_ELEMENT_MSG); checkElementNotNull(rightArray.isNull(position), ARRAY_NULL_ELEMENT_MSG); - long result = (long) comparisonOperator.invokeExact(leftArray, position, rightArray, position); + int leftIndex = leftArray.getUnderlyingValuePosition(position); + int rightIndex = rightArray.getUnderlyingValuePosition(position); + + long result = (long) comparisonOperator.invokeExact(leftValues, leftIndex, rightValues, rightIndex); if (result != 0) { return result; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java index 8c6bd5dc61ac..7fb7b46fed51 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/BigintType.java @@ -37,7 +37,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getLong(position, 0); + return getLong(block, position); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java b/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java index cab45b0f1599..d2195f2c1619 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/BooleanType.java @@ -21,6 +21,8 @@ import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -67,7 +69,7 @@ public static Block createBlockForSingleNonNullValue(boolean value) private BooleanType() { - super(new TypeSignature(StandardTypes.BOOLEAN), boolean.class); + super(new TypeSignature(StandardTypes.BOOLEAN), boolean.class, ByteArrayBlock.class); } @Override @@ -128,7 +130,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getByte(position, 0) != 0; + return getBoolean(block, position); } @Override @@ -138,14 +140,14 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((ByteArrayBlockBuilder) blockBuilder).writeByte(block.getByte(position, 0)); + ((ByteArrayBlockBuilder) blockBuilder).writeByte(getBoolean(block, position) ? (byte) 1 : 0); } } @Override public boolean getBoolean(Block block, int position) { - return block.getByte(position, 0) != 0; + return read((ByteArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -172,6 +174,12 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static boolean read(@BlockPosition ByteArrayBlock block, @BlockIndex int position) + { + return block.getByte(position) != 0; + } + @ScalarOperator(READ_VALUE) private static boolean readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java index ed48d409e2b0..b9e1967a848a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.ScalarOperator; @@ -125,7 +126,7 @@ public Optional getRange() if (!cachedRangePresent) { if (length > 100) { // The max/min values may be materialized in the plan, so we don't want them to be too large. - // Range comparison against large values are usually nonsensical, too, so no need to support them + // Range comparison against large values is usually nonsensical, too, so no need to support them // beyond a certain size. They specific choice above is arbitrary and can be adjusted if needed. range = Optional.empty(); } @@ -158,7 +159,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + Slice slice = getSlice(block, position); if (slice.length() > 0) { if (countCodePoints(slice) > length) { throw new IllegalArgumentException(format("Character count exceeds length limit %s: %s", length, sliceRepresentation(slice))); @@ -182,7 +183,9 @@ public VariableWidthBlockBuilder createBlockBuilder(BlockBuilderStatus blockBuil @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java index be74bb309a49..5dad5193338f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DateType.java @@ -23,7 +23,7 @@ // // Note: when dealing with a java.sql.Date it is important to remember that the value is stored // as the number of milliseconds from 1970-01-01T00:00:00 in UTC but time must be midnight in -// the local time zone. This mean when converting between a java.sql.Date and this +// the local time zone. This means when converting between a java.sql.Date and this // type, the time zone offset must be added or removed to keep the time at midnight in UTC. // public final class DateType @@ -43,7 +43,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - int days = block.getInt(position, 0); + int days = getInt(block, position); return new SqlDate(days); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java index efaaba6a2f51..828780d5ba16 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DecimalType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import java.util.List; @@ -60,9 +61,9 @@ public static DecimalType createDecimalType() private final int precision; private final int scale; - DecimalType(int precision, int scale, Class javaType) + DecimalType(int precision, int scale, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.DECIMAL, buildTypeParameters(precision, scale)), javaType); + super(new TypeSignature(StandardTypes.DECIMAL, buildTypeParameters(precision, scale)), javaType, valueBlockType); this.precision = precision; this.scale = scale; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java b/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java index 25e506abb701..4a9175875e70 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/DoubleType.java @@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +59,7 @@ public final class DoubleType private DoubleType() { - super(new TypeSignature(StandardTypes.DOUBLE), double.class); + super(new TypeSignature(StandardTypes.DOUBLE), double.class, LongArrayBlock.class); } @Override @@ -89,7 +92,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - return longBitsToDouble(block.getLong(position, 0)); + return getDouble(block, position); } @Override @@ -99,14 +102,16 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((LongArrayBlockBuilder) blockBuilder).writeLong(block.getLong(position, 0)); + LongArrayBlock valueBlock = (LongArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((LongArrayBlockBuilder) blockBuilder).writeLong(valueBlock.getLong(valuePosition)); } } @Override public double getDouble(Block block, int position) { - return longBitsToDouble(block.getLong(position, 0)); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -169,6 +174,12 @@ public Optional getRange() return Optional.empty(); } + @ScalarOperator(READ_VALUE) + private static double read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return longBitsToDouble(block.getLong(position)); + } + @ScalarOperator(READ_VALUE) private static double readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -189,6 +200,7 @@ private static void writeFlat( DOUBLE_HANDLE.set(fixedSizeSlice, fixedSizeOffset, value); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(EQUAL) private static boolean equalOperator(double left, double right) { @@ -213,6 +225,7 @@ public static long xxHash64(double value) return XxHash64.hash(doubleToLongBits(value)); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(IS_DISTINCT_FROM) private static boolean distinctFromOperator(double left, @IsNull boolean leftNull, double right, @IsNull boolean rightNull) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java b/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java index a128184e4a75..15be6c253302 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/HyperLogLogType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -37,7 +38,9 @@ public HyperLogLogType() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -59,6 +62,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java b/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java index 0726dabb941d..7d9da63bbaba 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/IntegerType.java @@ -37,7 +37,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getInt(position, 0); + return getInt(block, position); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java index 42b31d537f9c..1519ecacecfd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongDecimalType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -49,7 +50,7 @@ final class LongDecimalType LongDecimalType(int precision, int scale) { - super(precision, scale, Int128.class); + super(precision, scale, Int128.class, Int128ArrayBlock.class); checkArgument(Decimals.MAX_SHORT_PRECISION < precision && precision <= Decimals.MAX_PRECISION, "Invalid precision: %s", precision); checkArgument(0 <= scale && scale <= precision, "Invalid scale for precision %s: %s", precision, scale); } @@ -99,7 +100,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Int128 value = (Int128) getObject(block, position); + Int128 value = getObject(block, position); BigInteger unscaledValue = value.toBigInteger(); return new SqlDecimal(unscaledValue, getPrecision(), getScale()); } @@ -111,9 +112,9 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Int128ArrayBlockBuilder) blockBuilder).writeInt128( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((Int128ArrayBlockBuilder) blockBuilder).writeInt128(valueBlock.getInt128High(valuePosition), valueBlock.getInt128Low(valuePosition)); } } @@ -125,11 +126,9 @@ public void writeObject(BlockBuilder blockBuilder, Object value) } @Override - public Object getObject(Block block, int position) + public Int128 getObject(Block block, int position) { - return Int128.valueOf( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + return read((Int128ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -138,6 +137,12 @@ public int getFlatFixedSize() return INT128_BYTES; } + @ScalarOperator(READ_VALUE) + private static Int128 read(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) + { + return block.getInt128(position); + } + @ScalarOperator(READ_VALUE) private static Int128 readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -175,15 +180,15 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockToFlat( - @BlockPosition Block block, + @BlockPosition Int128ArrayBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] unusedVariableSizeSlice, int unusedVariableSizeOffset) { - LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset, block.getLong(position, 0)); - LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset + SIZE_OF_LONG, block.getLong(position, SIZE_OF_LONG)); + LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset, block.getInt128High(position)); + LONG_HANDLE.set(fixedSizeSlice, fixedSizeOffset + SIZE_OF_LONG, block.getInt128Low(position)); } @ScalarOperator(EQUAL) @@ -193,10 +198,10 @@ private static boolean equalOperator(Int128 left, Int128 right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { - return leftBlock.getLong(leftPosition, 0) == rightBlock.getLong(rightPosition, 0) && - leftBlock.getLong(leftPosition, SIZE_OF_LONG) == rightBlock.getLong(rightPosition, SIZE_OF_LONG); + return leftBlock.getInt128High(leftPosition) == rightBlock.getInt128High(rightPosition) && + leftBlock.getInt128Low(leftPosition) == rightBlock.getInt128Low(rightPosition); } @ScalarOperator(XX_HASH_64) @@ -206,9 +211,9 @@ private static long xxHash64Operator(Int128 value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) { - return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); + return xxHash64(block.getInt128High(position), block.getInt128Low(position)); } private static long xxHash64(long low, long high) @@ -223,12 +228,12 @@ private static long comparisonOperator(Int128 left, Int128 right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return Int128.compare( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java index 57e6573fdce5..9fc12b124a0f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimeWithTimeZoneType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -53,7 +54,7 @@ final class LongTimeWithTimeZoneType public LongTimeWithTimeZoneType(int precision) { - super(precision, LongTimeWithTimeZone.class); + super(precision, LongTimeWithTimeZone.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); @@ -106,14 +107,18 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - write(blockBuilder, getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + write(blockBuilder, getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } } @Override - public Object getObject(Block block, int position) + public LongTimeWithTimeZone getObject(Block block, int position) { - return new LongTimeWithTimeZone(getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new LongTimeWithTimeZone(getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } @Override @@ -135,7 +140,9 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return SqlTimeWithTimeZone.newInstance(getPrecision(), getPicos(block, position), getOffsetMinutes(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return SqlTimeWithTimeZone.newInstance(getPrecision(), getPicos(valueBlock, valuePosition), getOffsetMinutes(valueBlock, valuePosition)); } @Override @@ -144,14 +151,14 @@ public int getFlatFixedSize() return Long.BYTES + Integer.BYTES; } - private static long getPicos(Block block, int position) + private static long getPicos(Fixed12Block block, int position) { - return block.getLong(position, 0); + return block.getFixed12First(position); } - private static int getOffsetMinutes(Block block, int position) + private static int getOffsetMinutes(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @ScalarOperator(READ_VALUE) @@ -191,7 +198,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -213,7 +220,7 @@ private static boolean equalOperator(LongTimeWithTimeZone left, LongTimeWithTime } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getPicos(leftBlock, leftPosition), @@ -234,7 +241,7 @@ private static long hashCodeOperator(LongTimeWithTimeZone value) } @ScalarOperator(HASH_CODE) - private static long hashCodeOperator(@BlockPosition Block block, @BlockIndex int position) + private static long hashCodeOperator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return hashCodeOperator(getPicos(block, position), getOffsetMinutes(block, position)); } @@ -251,7 +258,7 @@ private static long xxHash64Operator(LongTimeWithTimeZone value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return xxHash64(getPicos(block, position), getOffsetMinutes(block, position)); } @@ -272,7 +279,7 @@ private static long comparisonOperator(LongTimeWithTimeZone left, LongTimeWithTi } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getPicos(leftBlock, leftPosition), @@ -297,7 +304,7 @@ private static boolean lessThanOperator(LongTimeWithTimeZone left, LongTimeWithT } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getPicos(leftBlock, leftPosition), @@ -322,7 +329,7 @@ private static boolean lessThanOrEqualOperator(LongTimeWithTimeZone left, LongTi } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getPicos(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java index c7c5d0e6426c..0e13a7b7331e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -51,7 +52,7 @@ * in the first long and the fractional increment in the remaining integer, as * a number of picoseconds additional to the epoch microsecond. */ -class LongTimestampType +final class LongTimestampType extends TimestampType { private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(LongTimestampType.class, lookup(), LongTimestamp.class); @@ -61,13 +62,13 @@ class LongTimestampType public LongTimestampType(int precision) { - super(precision, LongTimestamp.class); + super(precision, LongTimestamp.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); } - // ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things. + // ShortTimestampType instances are created eagerly and shared, so it's OK to precompute some things. int picosOfMicroMax = toIntExact(PICOSECONDS_PER_MICROSECOND - rescale(1, 0, 12 - getPrecision())); range = new Range(new LongTimestamp(Long.MIN_VALUE, 0), new LongTimestamp(Long.MAX_VALUE, picosOfMicroMax)); } @@ -118,16 +119,18 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Fixed12BlockBuilder) blockBuilder).writeFixed12( - getEpochMicros(block, position), - getFraction(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((Fixed12BlockBuilder) blockBuilder).writeFixed12(getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } } @Override public Object getObject(Block block, int position) { - return new LongTimestamp(getEpochMicros(block, position), getFraction(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return new LongTimestamp(getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } @Override @@ -149,10 +152,9 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long epochMicros = getEpochMicros(block, position); - int fraction = getFraction(block, position); - - return SqlTimestamp.newInstance(getPrecision(), epochMicros, fraction); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return SqlTimestamp.newInstance(getPrecision(), getEpochMicros(valueBlock, valuePosition), getFraction(valueBlock, valuePosition)); } @Override @@ -161,14 +163,14 @@ public int getFlatFixedSize() return Long.BYTES + Integer.BYTES; } - private static long getEpochMicros(Block block, int position) + private static long getEpochMicros(Fixed12Block block, int position) { - return block.getLong(position, 0); + return block.getFixed12First(position); } - private static int getFraction(Block block, int position) + private static int getFraction(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @Override @@ -214,7 +216,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -236,7 +238,7 @@ private static boolean equalOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getEpochMicros(leftBlock, leftPosition), @@ -257,7 +259,7 @@ private static long xxHash64Operator(LongTimestamp value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return xxHash64( getEpochMicros(block, position), @@ -276,7 +278,7 @@ private static long comparisonOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getEpochMicros(leftBlock, leftPosition), @@ -301,7 +303,7 @@ private static boolean lessThanOperator(LongTimestamp left, LongTimestamp right) } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getEpochMicros(leftBlock, leftPosition), @@ -323,7 +325,7 @@ private static boolean lessThanOrEqualOperator(LongTimestamp left, LongTimestamp } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getEpochMicros(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java index 1e131f13b7ce..cd58e49d993c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/LongTimestampWithTimeZoneType.java @@ -17,6 +17,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Fixed12Block; import io.trino.spi.block.Fixed12BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -64,7 +65,7 @@ final class LongTimestampWithTimeZoneType public LongTimestampWithTimeZoneType(int precision) { - super(precision, LongTimestampWithTimeZone.class); + super(precision, LongTimestampWithTimeZone.class, Fixed12Block.class); if (precision < MAX_SHORT_PRECISION + 1 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [%s, %s]", MAX_SHORT_PRECISION + 1, MAX_PRECISION)); @@ -117,15 +118,19 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - write(blockBuilder, getPackedEpochMillis(block, position), getPicosOfMilli(block, position)); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + write(blockBuilder, getPackedEpochMillis(valueBlock, valuePosition), getPicosOfMilli(valueBlock, valuePosition)); } } @Override public Object getObject(Block block, int position) { - long packedEpochMillis = getPackedEpochMillis(block, position); - int picosOfMilli = getPicosOfMilli(block, position); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long packedEpochMillis = getPackedEpochMillis(valueBlock, valuePosition); + int picosOfMilli = getPicosOfMilli(valueBlock, valuePosition); return LongTimestampWithTimeZone.fromEpochMillisAndFraction(unpackMillisUtc(packedEpochMillis), picosOfMilli, unpackZoneKey(packedEpochMillis)); } @@ -152,8 +157,10 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long packedEpochMillis = getPackedEpochMillis(block, position); - int picosOfMilli = getPicosOfMilli(block, position); + Fixed12Block valueBlock = (Fixed12Block) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long packedEpochMillis = getPackedEpochMillis(valueBlock, valuePosition); + int picosOfMilli = getPicosOfMilli(valueBlock, valuePosition); return SqlTimestampWithTimeZone.newInstance(getPrecision(), unpackMillisUtc(packedEpochMillis), picosOfMilli, unpackZoneKey(packedEpochMillis)); } @@ -200,19 +207,19 @@ public Optional getNextValue(Object value) return Optional.of(LongTimestampWithTimeZone.fromEpochMillisAndFraction(epochMillis, picosOfMilli, UTC_KEY)); } - private static long getPackedEpochMillis(Block block, int position) + private static long getPackedEpochMillis(Fixed12Block block, int position) { - return block.getLong(position, 0); + return block.getFixed12First(position); } - private static long getEpochMillis(Block block, int position) + private static long getEpochMillis(Fixed12Block block, int position) { return unpackMillisUtc(getPackedEpochMillis(block, position)); } - private static int getPicosOfMilli(Block block, int position) + private static int getPicosOfMilli(Fixed12Block block, int position) { - return block.getInt(position, SIZE_OF_LONG); + return block.getFixed12Second(position); } @ScalarOperator(READ_VALUE) @@ -252,7 +259,7 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockFlat( - @BlockPosition Block block, + @BlockPosition Fixed12Block block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, @@ -274,7 +281,7 @@ private static boolean equalOperator(LongTimestampWithTimeZone left, LongTimesta } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return equal( getEpochMillis(leftBlock, leftPosition), @@ -296,7 +303,7 @@ private static long xxHash64Operator(LongTimestampWithTimeZone value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Fixed12Block block, @BlockIndex int position) { return xxHash64( getEpochMillis(block, position), @@ -315,7 +322,7 @@ private static long comparisonOperator(LongTimestampWithTimeZone left, LongTimes } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return comparison( getEpochMillis(leftBlock, leftPosition), @@ -340,7 +347,7 @@ private static boolean lessThanOperator(LongTimestampWithTimeZone left, LongTime } @ScalarOperator(LESS_THAN) - private static boolean lessThanOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThan( getEpochMillis(leftBlock, leftPosition), @@ -362,7 +369,7 @@ private static boolean lessThanOrEqualOperator(LongTimestampWithTimeZone left, L } @ScalarOperator(LESS_THAN_OR_EQUAL) - private static boolean lessThanOrEqualOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean lessThanOrEqualOperator(@BlockPosition Fixed12Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Fixed12Block rightBlock, @BlockIndex int rightPosition) { return lessThanOrEqual( getEpochMillis(leftBlock, leftPosition), diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java b/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java index 2e2283a13746..049a0594db92 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/MapType.java @@ -111,7 +111,7 @@ public class MapType private final MethodHandle keyBlockNativeEqual; private final MethodHandle keyBlockEqual; - // this field is used in double checked locking + // this field is used in double-checked locking @SuppressWarnings("FieldAccessedSynchronizedAndUnsynchronized") private volatile TypeOperatorDeclaration typeOperatorDeclaration; @@ -122,7 +122,8 @@ public MapType(Type keyType, Type valueType, TypeOperators typeOperators) StandardTypes.MAP, TypeSignatureParameter.typeParameter(keyType.getTypeSignature()), TypeSignatureParameter.typeParameter(valueType.getTypeSignature())), - SqlMap.class); + SqlMap.class, + MapBlock.class); if (!keyType.isComparable()) { throw new IllegalArgumentException(format("key type must be comparable, got %s", keyType)); } @@ -291,7 +292,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); @@ -318,7 +319,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public SqlMap getObject(Block block, int position) { - return block.getObject(position, SqlMap.class); + return read((MapBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -463,7 +464,7 @@ public String getDisplayName() return "map(" + keyType.getDisplayName() + ", " + valueType.getDisplayName() + ")"; } - public Block createBlockFromKeyValue(Optional mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) + public MapBlock createBlockFromKeyValue(Optional mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) { return MapBlock.fromKeyValueBlock( mapIsNull, @@ -544,6 +545,11 @@ private static long invokeHashOperator(MethodHandle hashOperator, Block block, i return (long) hashOperator.invokeExact((Block) block, position); } + private static SqlMap read(MapBlock block, int position) + { + return block.getMap(position); + } + private static SqlMap readFlat( MapType mapType, MethodHandle keyReadOperator, @@ -825,7 +831,7 @@ private static boolean indeterminate(MethodHandle valueIndeterminateFunction, Sq Block rawValueBlock = sqlMap.getRawValueBlock(); for (int i = 0; i < sqlMap.getSize(); i++) { - // since maps are not allowed to have indeterminate keys we only check values here + // since maps are not allowed to have indeterminate keys, we only check values here if (rawValueBlock.isNull(rawOffset + i)) { return true; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java b/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java index 043a54ed99a3..7f3bb4cb785e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/QuantileDigestType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -39,7 +40,9 @@ public QuantileDigestType(Type valueType) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -61,7 +64,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } public Type getValueType() diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java b/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java index 50385018b295..da26b556cefd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/RealType.java @@ -76,7 +76,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position public float getFloat(Block block, int position) { - return intBitsToFloat(block.getInt(position, 0)); + return intBitsToFloat(getInt(block, position)); } @Override @@ -137,6 +137,7 @@ private static void writeFlat( INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, (int) value); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(EQUAL) private static boolean equalOperator(long left, long right) { @@ -163,6 +164,7 @@ private static long xxHash64Operator(long value) return XxHash64.hash(floatToIntBits(realValue)); } + @SuppressWarnings("FloatingPointEquality") @ScalarOperator(IS_DISTINCT_FROM) private static boolean distinctFromOperator(long left, @IsNull boolean leftNull, long right, @IsNull boolean rightNull) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java b/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java index 65ccdfa9cab1..a5115075e4d5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/RowType.java @@ -18,6 +18,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; import io.trino.spi.connector.ConnectorSession; @@ -127,7 +128,7 @@ public class RowType private RowType(TypeSignature typeSignature, List originalFields) { - super(typeSignature, SqlRow.class); + super(typeSignature, SqlRow.class, RowBlock.class); this.fields = List.copyOf(originalFields); this.fieldTypes = fields.stream() @@ -268,7 +269,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) @Override public SqlRow getObject(Block block, int position) { - return block.getObject(position, SqlRow.class); + return read((RowBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -425,6 +426,11 @@ private List getReadValueOperatorMethodHandles(TypeOperato new OperatorMethodHandle(WRITE_FLAT_CONVENTION, writeFlat)); } + private static SqlRow read(RowBlock block, int position) + { + return block.getRow(position); + } + private static SqlRow megamorphicReadFlat( RowType rowType, List fieldReadFlatMethods, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java index 0460e018420a..adda87f36d0c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortDecimalType.java @@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -69,8 +72,8 @@ static ShortDecimalType getInstance(int precision, int scale) private ShortDecimalType(int precision, int scale) { - super(precision, scale, long.class); - checkArgument(0 < precision && precision <= Decimals.MAX_SHORT_PRECISION, "Invalid precision: %s", precision); + super(precision, scale, long.class, LongArrayBlock.class); + checkArgument(0 < precision && precision <= MAX_SHORT_PRECISION, "Invalid precision: %s", precision); checkArgument(0 <= scale && scale <= precision, "Invalid scale for precision %s: %s", precision, scale); } @@ -119,8 +122,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - long unscaledValue = block.getLong(position, 0); - return new SqlDecimal(BigInteger.valueOf(unscaledValue), getPrecision(), getScale()); + return new SqlDecimal(BigInteger.valueOf(getLong(block, position)), getPrecision(), getScale()); } @Override @@ -130,14 +132,14 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, getLong(block, position)); } } @Override public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -158,6 +160,12 @@ public Optional> getDiscreteValues(Range range) return Optional.of(LongStream.rangeClosed((long) range.getMin(), (long) range.getMax()).boxed()); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java index 2297ae355180..d679a0d37357 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimeWithTimeZoneType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +58,7 @@ final class ShortTimeWithTimeZoneType public ShortTimeWithTimeZoneType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); @@ -70,42 +72,36 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); - } - - @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, getLong(block, position)); } } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -120,13 +116,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new LongArrayBlockBuilder(null, positionCount); } @@ -138,7 +134,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long value = block.getLong(position, 0); + long value = getLong(block, position); return SqlTimeWithTimeZone.newInstance(getPrecision(), unpackTimeNanos(value) * PICOSECONDS_PER_NANOSECOND, unpackOffsetMinutes(value)); } @@ -148,6 +144,12 @@ public int getFlatFixedSize() return Long.BYTES; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java index f3ea087dc1d4..6a86b1febd78 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampType.java @@ -17,9 +17,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -48,7 +51,7 @@ * The value is encoded as microseconds from the 1970-01-01 00:00:00 epoch and is to be interpreted as * local date time without regards to any time zone. */ -class ShortTimestampType +final class ShortTimestampType extends TimestampType { private static final TypeOperatorDeclaration TYPE_OPERATOR_DECLARATION = extractOperatorDeclaration(ShortTimestampType.class, lookup(), long.class); @@ -57,13 +60,13 @@ class ShortTimestampType public ShortTimestampType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); } - // ShortTimestampType instances are created eagerly and shared so it's OK to precompute some things. + // ShortTimestampType instances are created eagerly and shared, so it's OK to precompute some things. if (getPrecision() == MAX_SHORT_PRECISION) { range = new Range(Long.MIN_VALUE, Long.MAX_VALUE); } @@ -80,25 +83,25 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); @@ -109,7 +112,7 @@ public final void appendTo(Block block, int position, BlockBuilder blockBuilder) } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -124,13 +127,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new LongArrayBlockBuilder(null, positionCount); } @@ -176,6 +179,12 @@ public Optional getNextValue(Object value) return Optional.of((long) value + rescale(1_000_000, getPrecision(), 0)); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java index 42db05208b69..401ba757e344 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/ShortTimestampWithTimeZoneType.java @@ -13,14 +13,16 @@ */ package io.trino.spi.type; -import io.airlift.slice.Slice; import io.airlift.slice.XxHash64; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.LongArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +58,7 @@ final class ShortTimestampWithTimeZoneType public ShortTimestampWithTimeZoneType(int precision) { - super(precision, long.class); + super(precision, long.class, LongArrayBlock.class); if (precision < 0 || precision > MAX_SHORT_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_SHORT_PRECISION)); @@ -70,42 +72,36 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper } @Override - public final int getFixedSize() + public int getFixedSize() { return Long.BYTES; } @Override - public final long getLong(Block block, int position) + public long getLong(Block block, int position) { - return block.getLong(position, 0); + return read((LongArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override - public final Slice getSlice(Block block, int position) - { - return block.getSlice(position, 0, getFixedSize()); - } - - @Override - public final void writeLong(BlockBuilder blockBuilder, long value) + public void writeLong(BlockBuilder blockBuilder, long value) { ((LongArrayBlockBuilder) blockBuilder).writeLong(value); } @Override - public final void appendTo(Block block, int position, BlockBuilder blockBuilder) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { if (block.isNull(position)) { blockBuilder.appendNull(); } else { - writeLong(blockBuilder, block.getLong(position, 0)); + writeLong(blockBuilder, getLong(block, position)); } } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { int maxBlockSizeInBytes; if (blockBuilderStatus == null) { @@ -120,13 +116,13 @@ public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStat } @Override - public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { return createBlockBuilder(blockBuilderStatus, expectedEntries, Long.BYTES); } @Override - public final BlockBuilder createFixedSizeBlockBuilder(int positionCount) + public BlockBuilder createFixedSizeBlockBuilder(int positionCount) { return new LongArrayBlockBuilder(null, positionCount); } @@ -138,7 +134,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - long value = block.getLong(position, 0); + long value = getLong(block, position); return SqlTimestampWithTimeZone.newInstance(getPrecision(), unpackMillisUtc(value), 0, unpackZoneKey(value)); } @@ -148,6 +144,12 @@ public int getFlatFixedSize() return Long.BYTES; } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition LongArrayBlock block, @BlockIndex int position) + { + return block.getLong(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java index fb77472356a6..2114679bf4b5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/SmallintType.java @@ -19,8 +19,11 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.ShortArrayBlock; import io.trino.spi.block.ShortArrayBlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -56,7 +59,7 @@ public final class SmallintType private SmallintType() { - super(new TypeSignature(StandardTypes.SMALLINT), long.class); + super(new TypeSignature(StandardTypes.SMALLINT), long.class, ShortArrayBlock.class); } @Override @@ -117,7 +120,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getShort(position, 0); + return getShort(block, position); } @Override @@ -161,7 +164,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((ShortArrayBlockBuilder) blockBuilder).writeShort(block.getShort(position, 0)); + ((ShortArrayBlockBuilder) blockBuilder).writeShort(getShort(block, position)); } } @@ -173,7 +176,7 @@ public long getLong(Block block, int position) public short getShort(Block block, int position) { - return block.getShort(position, 0); + return readShort((ShortArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -188,7 +191,7 @@ public void writeShort(BlockBuilder blockBuilder, short value) ((ShortArrayBlockBuilder) blockBuilder).writeShort(value); } - private void checkValueValid(long value) + private static void checkValueValid(long value) { if (value > Short.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_SHORT", value)); @@ -217,6 +220,17 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition ShortArrayBlock block, @BlockIndex int position) + { + return readShort(block, position); + } + + private static short readShort(ShortArrayBlock block, int position) + { + return block.getShort(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java index 43533598b0e7..1c66925c69c7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimeType.java @@ -98,7 +98,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return SqlTime.newInstance(precision, block.getLong(position, 0)); + return SqlTime.newInstance(precision, getLong(block, position)); } @ScalarOperator(READ_VALUE) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java index e3d0fbc706c6..ee9e406080bb 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimeWithTimeZoneType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -59,9 +60,9 @@ public static TimeWithTimeZoneType createTimeWithTimeZoneType(int precision) return TYPES[precision]; } - protected TimeWithTimeZoneType(int precision, Class javaType) + protected TimeWithTimeZoneType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIME_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIME_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); this.precision = precision; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java index 5d6cd360371d..03749b781ade 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -24,9 +25,10 @@ * @see ShortTimestampType * @see LongTimestampType */ -public abstract class TimestampType +public abstract sealed class TimestampType extends AbstractType implements FixedWidthType + permits LongTimestampType, ShortTimestampType { public static final int MAX_PRECISION = 12; @@ -57,9 +59,9 @@ public static TimestampType createTimestampType(int precision) return TYPES[precision]; } - TimestampType(int precision, Class javaType) + TimestampType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIMESTAMP, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIMESTAMP, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); this.precision = precision; } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java index d900d47553c8..4f75e8176c5f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TimestampWithTimeZoneType.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.TrinoException; +import io.trino.spi.block.ValueBlock; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static java.lang.String.format; @@ -56,9 +57,9 @@ public static TimestampWithTimeZoneType createTimestampWithTimeZoneType(int prec return TYPES[precision]; } - TimestampWithTimeZoneType(int precision, Class javaType) + TimestampWithTimeZoneType(int precision, Class javaType, Class valueBlockType) { - super(new TypeSignature(StandardTypes.TIMESTAMP_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType); + super(new TypeSignature(StandardTypes.TIMESTAMP_WITH_TIME_ZONE, TypeSignatureParameter.numericParameter(precision)), javaType, valueBlockType); if (precision < 0 || precision > MAX_PRECISION) { throw new IllegalArgumentException(format("Precision must be in the range [0, %s]", MAX_PRECISION)); diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java b/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java index 10ed974602a2..b8b254eef3d4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TinyintType.java @@ -18,9 +18,12 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ByteArrayBlock; import io.trino.spi.block.ByteArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BlockIndex; +import io.trino.spi.function.BlockPosition; import io.trino.spi.function.FlatFixed; import io.trino.spi.function.FlatFixedOffset; import io.trino.spi.function.FlatVariableWidth; @@ -52,7 +55,7 @@ public final class TinyintType private TinyintType() { - super(new TypeSignature(StandardTypes.TINYINT), long.class); + super(new TypeSignature(StandardTypes.TINYINT), long.class, ByteArrayBlock.class); } @Override @@ -113,7 +116,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getByte(position, 0); + return getByte(block, position); } @Override @@ -157,7 +160,7 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - writeByte(blockBuilder, block.getByte(position, 0)); + writeByte(blockBuilder, getByte(block, position)); } } @@ -169,7 +172,7 @@ public long getLong(Block block, int position) public byte getByte(Block block, int position) { - return block.getByte(position, 0); + return readByte((ByteArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -184,7 +187,7 @@ public void writeByte(BlockBuilder blockBuilder, byte value) ((ByteArrayBlockBuilder) blockBuilder).writeByte(value); } - private void checkValueValid(long value) + private static void checkValueValid(long value) { if (value > Byte.MAX_VALUE) { throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Value %d exceeds MAX_BYTE", value)); @@ -212,6 +215,17 @@ public int hashCode() return getClass().hashCode(); } + @ScalarOperator(READ_VALUE) + private static long read(@BlockPosition ByteArrayBlock block, @BlockIndex int position) + { + return readByte(block, position); + } + + private static byte readByte(ByteArrayBlock block, int position) + { + return block.getByte(position); + } + @ScalarOperator(READ_VALUE) private static long readFlat( @FlatFixed byte[] fixedSizeSlice, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/Type.java b/core/trino-spi/src/main/java/io/trino/spi/type/Type.java index ae4c2347a272..519abbfbb462 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/Type.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/Type.java @@ -18,6 +18,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import java.util.List; @@ -81,6 +82,11 @@ default TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOpe */ Class getJavaType(); + /** + * Gets the ValueBlock type used to store values of this type. + */ + Class getValueBlockType(); + /** * For parameterized types returns the list of parameters. */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java index 71ffeedca4a6..bc604ded7dcd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java @@ -15,6 +15,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -46,6 +47,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -462,13 +465,18 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret case BLOCK_POSITION_NOT_NULL: case BLOCK_POSITION: checkArgument(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class), - "Expected BLOCK_POSITION argument have parameters Block and int"); + "Expected BLOCK_POSITION argument to have parameters Block and int"); + break; + case VALUE_BLOCK_POSITION_NOT_NULL: + case VALUE_BLOCK_POSITION: + checkArgument(Block.class.isAssignableFrom(parameterType) && methodType.parameterType(parameterIndex + 1).equals(int.class), + "Expected VALUE_BLOCK_POSITION argument to have parameters ValueBlock and int"); break; case FLAT: checkArgument(parameterType.equals(byte[].class) && methodType.parameterType(parameterIndex + 1).equals(int.class) && methodType.parameterType(parameterIndex + 2).equals(byte[].class), - "Expected FLAT argument have parameters byte[], int, and byte[]"); + "Expected FLAT argument to have parameters byte[], int, and byte[]"); break; case FUNCTION: throw new IllegalArgumentException("Function argument convention is not supported in type operators"); @@ -506,6 +514,10 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret default: throw new UnsupportedOperationException("Unknown return convention: " + returnConvention); } + + if (operatorMethodHandle.getCallingConvention().getArgumentConventions().stream().anyMatch(argumentConvention -> argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL)) { + throw new IllegalArgumentException("BLOCK_POSITION argument convention is not allowed for type operators"); + } } private static InvocationConvention parseInvocationConvention(OperatorType operatorType, Class typeJavaType, Method method, Class expectedReturnType) @@ -576,11 +588,14 @@ private static InvocationArgumentConvention extractNextArgumentConvention( Method method) { if (isAnnotationPresent(parameterAnnotations.get(0), BlockPosition.class)) { - if (parameterTypes.size() > 1 && - isAnnotationPresent(parameterAnnotations.get(1), BlockIndex.class) && - parameterTypes.get(0).equals(Block.class) && - parameterTypes.get(1).equals(int.class)) { - return isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class) ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL; + if (parameterTypes.size() > 1 && isAnnotationPresent(parameterAnnotations.get(1), BlockIndex.class)) { + if (!ValueBlock.class.isAssignableFrom(parameterTypes.get(0))) { + throw new IllegalArgumentException("@BlockPosition argument must be a ValueBlock type for %s operator: %s".formatted(operatorType, method)); + } + if (parameterTypes.get(1) != int.class) { + throw new IllegalArgumentException("@BlockIndex argument must be type int for %s operator: %s".formatted(operatorType, method)); + } + return isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class) ? VALUE_BLOCK_POSITION : VALUE_BLOCK_POSITION_NOT_NULL; } } else if (isAnnotationPresent(parameterAnnotations.get(0), SqlNullable.class)) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java index fd599d361bf2..7e5cd97a39ed 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeUtils.java @@ -18,6 +18,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import jakarta.annotation.Nullable; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; @@ -60,11 +61,11 @@ public static Object readNativeValue(Type type, Block block, int position) return type.getObject(block, position); } - public static Block writeNativeValue(Type type, @Nullable Object value) + public static ValueBlock writeNativeValue(Type type, @Nullable Object value) { BlockBuilder blockBuilder = type.createBlockBuilder(null, 1); writeNativeValue(type, blockBuilder, value); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } /** diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java b/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java index 2897b79c785f..228ad712bff7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/UuidType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.Int128ArrayBlock; import io.trino.spi.block.Int128ArrayBlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorSession; @@ -61,7 +62,7 @@ public class UuidType private UuidType() { - super(new TypeSignature(StandardTypes.UUID), Slice.class); + super(new TypeSignature(StandardTypes.UUID), Slice.class, Int128ArrayBlock.class); } @Override @@ -121,8 +122,10 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - long high = reverseBytes(block.getLong(position, 0)); - long low = reverseBytes(block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + long high = reverseBytes(valueBlock.getInt128High(valuePosition)); + long low = reverseBytes(valueBlock.getInt128Low(valuePosition)); return new UUID(high, low).toString(); } @@ -133,9 +136,9 @@ public void appendTo(Block block, int position, BlockBuilder blockBuilder) blockBuilder.appendNull(); } else { - ((Int128ArrayBlockBuilder) blockBuilder).writeInt128( - block.getLong(position, 0), - block.getLong(position, SIZE_OF_LONG)); + Int128ArrayBlock valueBlock = (Int128ArrayBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + ((Int128ArrayBlockBuilder) blockBuilder).writeInt128(valueBlock.getInt128High(valuePosition), valueBlock.getInt128Low(valuePosition)); } } @@ -159,10 +162,7 @@ public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int l @Override public final Slice getSlice(Block block, int position) { - Slice value = Slices.allocate(INT128_BYTES); - value.setLong(0, block.getLong(position, 0)); - value.setLong(SIZE_OF_LONG, block.getLong(position, SIZE_OF_LONG)); - return value; + return read((Int128ArrayBlock) block.getUnderlyingValueBlock(), block.getUnderlyingValuePosition(position)); } @Override @@ -189,6 +189,15 @@ public static UUID trinoUuidToJavaUuid(Slice uuid) reverseBytes(uuid.getLong(SIZE_OF_LONG))); } + @ScalarOperator(READ_VALUE) + private static Slice read(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) + { + Slice value = Slices.allocate(INT128_BYTES); + value.setLong(0, block.getInt128High(position)); + value.setLong(SIZE_OF_LONG, block.getInt128Low(position)); + return value; + } + @ScalarOperator(READ_VALUE) private static Slice readFlat( @FlatFixed byte[] fixedSizeSlice, @@ -232,13 +241,13 @@ private static boolean equalOperator(Slice left, Slice right) } @ScalarOperator(EQUAL) - private static boolean equalOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static boolean equalOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return equal( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } private static boolean equal(long leftLow, long leftHigh, long rightLow, long rightHigh) @@ -253,9 +262,9 @@ private static long xxHash64Operator(Slice value) } @ScalarOperator(XX_HASH_64) - private static long xxHash64Operator(@BlockPosition Block block, @BlockIndex int position) + private static long xxHash64Operator(@BlockPosition Int128ArrayBlock block, @BlockIndex int position) { - return xxHash64(block.getLong(position, 0), block.getLong(position, SIZE_OF_LONG)); + return xxHash64(block.getInt128High(position), block.getInt128Low(position)); } private static long xxHash64(long low, long high) @@ -274,13 +283,13 @@ private static long comparisonOperator(Slice left, Slice right) } @ScalarOperator(COMPARISON_UNORDERED_LAST) - private static long comparisonOperator(@BlockPosition Block leftBlock, @BlockIndex int leftPosition, @BlockPosition Block rightBlock, @BlockIndex int rightPosition) + private static long comparisonOperator(@BlockPosition Int128ArrayBlock leftBlock, @BlockIndex int leftPosition, @BlockPosition Int128ArrayBlock rightBlock, @BlockIndex int rightPosition) { return compareLittleEndian( - leftBlock.getLong(leftPosition, 0), - leftBlock.getLong(leftPosition, SIZE_OF_LONG), - rightBlock.getLong(rightPosition, 0), - rightBlock.getLong(rightPosition, SIZE_OF_LONG)); + leftBlock.getInt128High(leftPosition), + leftBlock.getInt128Low(leftPosition), + rightBlock.getInt128High(rightPosition), + rightBlock.getInt128Low(rightPosition)); } private static int compareLittleEndian(long leftLow64le, long leftHigh64le, long rightLow64le, long rightHigh64le) diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java index e001c983af65..07f192cb83cf 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VarbinaryType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -69,13 +70,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java b/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java index f69c769b5c0f..02aa8ff06d2e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java @@ -19,6 +19,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; @@ -132,7 +133,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); + Slice slice = getSlice(block, position); if (!isUnbounded() && countCodePoints(slice) > length) { throw new IllegalArgumentException(format("Character count exceeds length limit %s: %s", length, sliceRepresentation(slice))); } @@ -161,7 +162,7 @@ public Optional getRange() if (!cachedRangePresent) { if (length > 100) { // The max/min values may be materialized in the plan, so we don't want them to be too large. - // Range comparison against large values are usually nonsensical, too, so no need to support them + // Range comparison against large values is usually nonsensical, too, so no need to support them // beyond a certain size. They specific choice above is arbitrary and can be adjusted if needed. range = Optional.empty(); } @@ -184,7 +185,9 @@ public Optional getRange() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } public void writeString(BlockBuilder blockBuilder, String value) diff --git a/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java b/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java index 88d341cea864..540af7561c19 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java +++ b/core/trino-spi/src/test/java/io/trino/spi/block/TestLazyBlock.java @@ -66,21 +66,21 @@ public void testNestedGetLoadedBlock() List actualNotifications = new ArrayList<>(); Block arrayBlock = new IntArrayBlock(1, Optional.empty(), new int[] {0}); LazyBlock lazyArrayBlock = new LazyBlock(1, () -> arrayBlock); - Block dictionaryBlock = DictionaryBlock.create(2, lazyArrayBlock, new int[] {0, 0}); - LazyBlock lazyBlock = new LazyBlock(2, () -> dictionaryBlock); + Block rowBlock = RowBlock.fromFieldBlocks(2, Optional.empty(), new Block[]{lazyArrayBlock}); + LazyBlock lazyBlock = new LazyBlock(2, () -> rowBlock); LazyBlock.listenForLoads(lazyBlock, actualNotifications::add); Block loadedBlock = lazyBlock.getBlock(); - assertThat(loadedBlock).isInstanceOf(DictionaryBlock.class); - assertThat(((DictionaryBlock) loadedBlock).getDictionary()).isInstanceOf(LazyBlock.class); + assertThat(loadedBlock).isInstanceOf(RowBlock.class); + assertThat(((RowBlock) loadedBlock).getRawFieldBlocks().get(0)).isInstanceOf(LazyBlock.class); assertThat(actualNotifications).isEqualTo(ImmutableList.of(loadedBlock)); Block fullyLoadedBlock = lazyBlock.getLoadedBlock(); - assertThat(fullyLoadedBlock).isInstanceOf(DictionaryBlock.class); - assertThat(((DictionaryBlock) fullyLoadedBlock).getDictionary()).isInstanceOf(IntArrayBlock.class); + assertThat(fullyLoadedBlock).isInstanceOf(RowBlock.class); + assertThat(((RowBlock) fullyLoadedBlock).getRawFieldBlocks().get(0)).isInstanceOf(IntArrayBlock.class); assertThat(actualNotifications).isEqualTo(ImmutableList.of(loadedBlock, arrayBlock)); assertThat(lazyBlock.isLoaded()).isTrue(); - assertThat(dictionaryBlock.isLoaded()).isTrue(); + assertThat(rowBlock.isLoaded()).isTrue(); } private static void assertNotificationsRecursive(int depth, Block lazyBlock, List actualNotifications, List expectedNotifications) diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 678867e35749..ac0ba86d87d4 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -19,8 +19,12 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.spi.TrinoException; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Fixed12Block; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.ArrayType; @@ -36,6 +40,7 @@ import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.BitSet; +import java.util.EnumSet; import java.util.List; import java.util.stream.IntStream; @@ -52,6 +57,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; @@ -199,6 +206,58 @@ public void testAdaptFromBlockPositionNotNullObjects() verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } + @Test + public void testAdaptFromValueBlockPosition() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPosition"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromValueBlockPositionObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPositionObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromValueBlockPositionNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPosition"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromValueBlockPositionObjectsNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), VALUE_BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + String methodName = "valueBlockPositionObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + private static void verifyAllAdaptations( InvocationConvention actualConvention, String methodName, @@ -219,7 +278,7 @@ private static void verifyAllAdaptations( throws Throwable { List> allArgumentConventions = allCombinations( - ImmutableList.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION, FLAT, IN_OUT), + ImmutableList.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, BOXED_NULLABLE, NULL_FLAG, BLOCK_POSITION, VALUE_BLOCK_POSITION, FLAT, IN_OUT), argumentTypes.size()); for (List argumentConventions : allArgumentConventions) { for (InvocationReturnConvention returnConvention : InvocationReturnConvention.values()) { @@ -258,7 +317,8 @@ private static void adaptAndVerify( assertThat(expectedConvention.getReturnConvention() == FAIL_ON_NULL || expectedConvention.getReturnConvention() == FLAT_RETURN).isTrue(); return; } - if (actualConvention.getArgumentConventions().stream().anyMatch(convention -> convention == BLOCK_POSITION || convention == BLOCK_POSITION_NOT_NULL)) { + if (actualConvention.getArgumentConventions().stream() + .anyMatch(convention -> EnumSet.of(BLOCK_POSITION, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION, VALUE_BLOCK_POSITION_NOT_NULL).contains(convention))) { return; } } @@ -343,7 +403,7 @@ private static boolean canCallConventionWithNullArguments(InvocationConvention c { for (int i = 0; i < convention.getArgumentConventions().size(); i++) { InvocationArgumentConvention argumentConvention = convention.getArgumentConvention(i); - if (nullArguments.get(i) && (argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT)) { + if (nullArguments.get(i) && EnumSet.of(NEVER_NULL, BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL, FLAT).contains(argumentConvention)) { return false; } } @@ -382,6 +442,10 @@ private static List> toCallArgumentTypes(InvocationConvention callingCo expectedArguments.add(Block.class); expectedArguments.add(int.class); } + case VALUE_BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION -> { + expectedArguments.add(argumentType.getValueBlockType()); + expectedArguments.add(int.class); + } case FLAT -> { expectedArguments.add(Slice.class); expectedArguments.add(int.class); @@ -423,21 +487,31 @@ private static List toCallArgumentValues(InvocationConvention callingCon callArguments.add(testValue == null ? Defaults.defaultValue(argumentType.getJavaType()) : testValue); callArguments.add(testValue == null); } - case BLOCK_POSITION_NOT_NULL -> { + case BLOCK_POSITION_NOT_NULL, VALUE_BLOCK_POSITION_NOT_NULL -> { verify(testValue != null, "null cannot be passed to a block positions not null argument"); BlockBuilder blockBuilder = argumentType.createBlockBuilder(null, 3); blockBuilder.appendNull(); writeNativeValue(argumentType, blockBuilder, testValue); blockBuilder.appendNull(); - callArguments.add(blockBuilder.build()); + if (argumentConvention == BLOCK_POSITION_NOT_NULL) { + callArguments.add(blockBuilder.build()); + } + else { + callArguments.add(blockBuilder.buildValueBlock()); + } callArguments.add(1); } - case BLOCK_POSITION -> { + case BLOCK_POSITION, VALUE_BLOCK_POSITION -> { BlockBuilder blockBuilder = argumentType.createBlockBuilder(null, 3); blockBuilder.appendNull(); writeNativeValue(argumentType, blockBuilder, testValue); blockBuilder.appendNull(); - callArguments.add(blockBuilder.build()); + if (argumentConvention == BLOCK_POSITION) { + callArguments.add(blockBuilder.build()); + } + else { + callArguments.add(blockBuilder.buildValueBlock()); + } callArguments.add(1); } case FLAT -> { @@ -736,6 +810,80 @@ public boolean blockPositionObjects( return true; } + @SuppressWarnings("unused") + public boolean valueBlockPosition( + LongArrayBlock doubleBlock, int doublePosition, + VariableWidthBlock sliceBlock, int slicePosition, + ArrayBlock blockBlock, int blockPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = false; + + if (doubleBlock.isNull(doublePosition)) { + this.doubleValue = null; + } + else { + this.doubleValue = DOUBLE.getDouble(doubleBlock, doublePosition); + } + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + return true; + } + + @SuppressWarnings("unused") + public boolean valueBlockPositionObjects( + VariableWidthBlock sliceBlock, int slicePosition, + ArrayBlock blockBlock, int blockPosition, + VariableWidthBlock objectCharBlock, int objectCharPosition, + Fixed12Block objectTimestampBlock, int objectTimestampPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = true; + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + + if (objectCharBlock.isNull(objectCharPosition)) { + this.objectCharValue = null; + } + else { + this.objectCharValue = CHAR_TYPE.getObject(objectCharBlock, objectCharPosition); + } + + if (objectTimestampBlock.isNull(objectTimestampPosition)) { + this.objectTimestampValue = null; + } + else { + this.objectTimestampValue = TIMESTAMP_TYPE.getObject(objectTimestampBlock, objectTimestampPosition); + } + return true; + } + public void verify( InvocationConvention actualConvention, BitSet nullArguments, @@ -781,7 +929,7 @@ private static boolean shouldFunctionBeInvoked(InvocationConvention actualConven { for (int i = 0; i < actualConvention.getArgumentConventions().size(); i++) { InvocationArgumentConvention argumentConvention = actualConvention.getArgumentConvention(i); - if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT) && nullArguments.get(i)) { + if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == VALUE_BLOCK_POSITION_NOT_NULL || argumentConvention == FLAT) && nullArguments.get(i)) { return false; } } diff --git a/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java b/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java index 723396c57120..1d60f9c89932 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java +++ b/core/trino-spi/src/test/java/io/trino/spi/type/TestingIdType.java @@ -40,7 +40,7 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return block.getLong(position, 0); + return getLong(block, position); } @Override diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java index 02e40bf6c8a2..d05069d1ce90 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/BinaryColumnEncodingFactory.java @@ -85,19 +85,19 @@ public BinaryColumnEncoding getEncoding(Type type) if (type instanceof TimestampType) { return new TimestampEncoding((TimestampType) type, timeZone); } - if (type instanceof ArrayType) { - return new ListEncoding(type, getEncoding(type.getTypeParameters().get(0))); + if (type instanceof ArrayType arrayType) { + return new ListEncoding(arrayType, getEncoding(arrayType.getElementType())); } - if (type instanceof MapType) { + if (type instanceof MapType mapType) { return new MapEncoding( - type, - getEncoding(type.getTypeParameters().get(0)), - getEncoding(type.getTypeParameters().get(1))); + mapType, + getEncoding(mapType.getKeyType()), + getEncoding(mapType.getValueType())); } - if (type instanceof RowType) { + if (type instanceof RowType rowType) { return new StructEncoding( - type, - type.getTypeParameters().stream() + rowType, + rowType.getTypeParameters().stream() .map(this::getEncoding) .collect(Collectors.toList())); } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java index 4aff509f9d29..5843b3ad853d 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/ListEncoding.java @@ -19,25 +19,27 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; +import io.trino.spi.type.ArrayType; import static java.lang.Math.toIntExact; public class ListEncoding extends BlockEncoding { + private final ArrayType arrayType; private final BinaryColumnEncoding elementEncoding; - public ListEncoding(Type type, BinaryColumnEncoding elementEncoding) + public ListEncoding(ArrayType arrayType, BinaryColumnEncoding elementEncoding) { - super(type); + super(arrayType); + this.arrayType = arrayType; this.elementEncoding = elementEncoding; } @Override public void encodeValue(Block block, int position, SliceOutput output) { - Block list = block.getObject(position, Block.class); + Block list = arrayType.getObject(block, position); ReadWriteUtils.writeVInt(output, list.getPositionCount()); // write null bits diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java index 9238b8585ef4..ace73d1f942e 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/MapEncoding.java @@ -22,19 +22,21 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; -import io.trino.spi.type.Type; +import io.trino.spi.type.MapType; import static java.lang.Math.toIntExact; public class MapEncoding extends BlockEncoding { + private final MapType mapType; private final BinaryColumnEncoding keyReader; private final BinaryColumnEncoding valueReader; - public MapEncoding(Type type, BinaryColumnEncoding keyReader, BinaryColumnEncoding valueReader) + public MapEncoding(MapType mapType, BinaryColumnEncoding keyReader, BinaryColumnEncoding valueReader) { - super(type); + super(mapType); + this.mapType = mapType; this.keyReader = keyReader; this.valueReader = valueReader; } @@ -42,7 +44,7 @@ public MapEncoding(Type type, BinaryColumnEncoding keyReader, BinaryColumnEncodi @Override public void encodeValue(Block block, int position, SliceOutput output) { - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java index 33d26e9d03cb..038deae7b97a 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/binary/StructEncoding.java @@ -20,7 +20,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; -import io.trino.spi.type.Type; +import io.trino.spi.type.RowType; import java.util.List; @@ -28,17 +28,19 @@ public class StructEncoding extends BlockEncoding { private final List structFields; + private final RowType rowType; - public StructEncoding(Type type, List structFields) + public StructEncoding(RowType rowType, List structFields) { - super(type); + super(rowType); + this.rowType = rowType; this.structFields = ImmutableList.copyOf(structFields); } @Override public void encodeValue(Block block, int position, SliceOutput output) { - SqlRow row = block.getObject(position, SqlRow.class); + SqlRow row = rowType.getObject(block, position); int rawIndex = row.getRawIndex(); // write values diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java index 2074e042a99b..a50d7ae5e078 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/ListEncoding.java @@ -19,17 +19,19 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; +import io.trino.spi.type.ArrayType; public class ListEncoding extends BlockEncoding { + private final ArrayType arrayType; private final byte separator; private final TextColumnEncoding elementEncoding; - public ListEncoding(Type type, Slice nullSequence, byte separator, Byte escapeByte, TextColumnEncoding elementEncoding) + public ListEncoding(ArrayType arrayType, Slice nullSequence, byte separator, Byte escapeByte, TextColumnEncoding elementEncoding) { - super(type, nullSequence, escapeByte); + super(arrayType, nullSequence, escapeByte); + this.arrayType = arrayType; this.separator = separator; this.elementEncoding = elementEncoding; } @@ -38,7 +40,7 @@ public ListEncoding(Type type, Slice nullSequence, byte separator, Byte escapeBy public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - Block list = block.getObject(position, Block.class); + Block list = arrayType.getObject(block, position); for (int elementIndex = 0; elementIndex < list.getPositionCount(); elementIndex++) { if (elementIndex > 0) { output.writeByte(separator); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java index 40840cd205e9..9006bd3325ce 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/MapEncoding.java @@ -24,7 +24,6 @@ import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; import io.trino.spi.type.MapType; -import io.trino.spi.type.Type; public class MapEncoding extends BlockEncoding @@ -39,7 +38,7 @@ public class MapEncoding private BlockBuilder keyBlockBuilder; public MapEncoding( - Type type, + MapType mapType, Slice nullSequence, byte elementSeparator, byte keyValueSeparator, @@ -47,8 +46,8 @@ public MapEncoding( TextColumnEncoding keyEncoding, TextColumnEncoding valueEncoding) { - super(type, nullSequence, escapeByte); - this.mapType = (MapType) type; + super(mapType, nullSequence, escapeByte); + this.mapType = mapType; this.elementSeparator = elementSeparator; this.keyValueSeparator = keyValueSeparator; this.keyEncoding = keyEncoding; @@ -61,7 +60,7 @@ public MapEncoding( public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java index b40f77d5daba..fb78ce553b7a 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java @@ -20,26 +20,28 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SqlRow; -import io.trino.spi.type.Type; +import io.trino.spi.type.RowType; import java.util.List; public class StructEncoding extends BlockEncoding { + private final RowType rowType; private final byte separator; private final boolean lastColumnTakesRest; private final List structFields; public StructEncoding( - Type type, + RowType rowType, Slice nullSequence, byte separator, Byte escapeByte, boolean lastColumnTakesRest, List structFields) { - super(type, nullSequence, escapeByte); + super(rowType, nullSequence, escapeByte); + this.rowType = rowType; this.separator = separator; this.lastColumnTakesRest = lastColumnTakesRest; this.structFields = structFields; @@ -49,7 +51,7 @@ public StructEncoding( public void encodeValueInto(Block block, int position, SliceOutput output) throws FileCorruptionException { - SqlRow row = block.getObject(position, SqlRow.class); + SqlRow row = rowType.getObject(block, position); int rawIndex = row.getRawIndex(); for (int fieldIndex = 0; fieldIndex < structFields.size(); fieldIndex++) { if (fieldIndex > 0) { diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java index 60f46091f4c7..24dbb1e33f4f 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java @@ -115,20 +115,20 @@ private TextColumnEncoding getEncoding(Type type, int depth) if (type instanceof TimestampType) { return new TimestampEncoding((TimestampType) type, textEncodingOptions.getNullSequence(), textEncodingOptions.getTimestampFormats()); } - if (type instanceof ArrayType) { - TextColumnEncoding elementEncoding = getEncoding(type.getTypeParameters().get(0), depth + 1); + if (type instanceof ArrayType arrayType) { + TextColumnEncoding elementEncoding = getEncoding(arrayType.getElementType(), depth + 1); return new ListEncoding( - type, + arrayType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), textEncodingOptions.getEscapeByte(), elementEncoding); } - if (type instanceof MapType) { - TextColumnEncoding keyEncoding = getEncoding(type.getTypeParameters().get(0), depth + 2); - TextColumnEncoding valueEncoding = getEncoding(type.getTypeParameters().get(1), depth + 2); + if (type instanceof MapType mapType) { + TextColumnEncoding keyEncoding = getEncoding(mapType.getKeyType(), depth + 2); + TextColumnEncoding valueEncoding = getEncoding(mapType.getValueType(), depth + 2); return new MapEncoding( - type, + mapType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), getSeparator(depth + 2), @@ -136,12 +136,12 @@ private TextColumnEncoding getEncoding(Type type, int depth) keyEncoding, valueEncoding); } - if (type instanceof RowType) { - List fieldEncodings = type.getTypeParameters().stream() + if (type instanceof RowType rowType) { + List fieldEncodings = rowType.getTypeParameters().stream() .map(fieldType -> getEncoding(fieldType, depth + 1)) .collect(toImmutableList()); return new StructEncoding( - type, + rowType, textEncodingOptions.getNullSequence(), getSeparator(depth + 1), textEncodingOptions.getEscapeByte(), diff --git a/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java b/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java index 5562871481c2..7b30700b5c37 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/ValidationHash.java @@ -21,7 +21,7 @@ import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; -import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -93,7 +93,7 @@ public static ValidationHash createValidationHash(Type type) return new ValidationHash(ROW_HASH.bindTo(rowType).bindTo(fieldHashes)); } - if (type.getTypeSignature().getBase().equals(StandardTypes.TIMESTAMP)) { + if (type instanceof TimestampType timestampType && timestampType.isShort()) { return new ValidationHash(TIMESTAMP_HASH); } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java index af5cd9195ccf..51a921ad5829 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/DictionaryBuilder.java @@ -14,10 +14,10 @@ package io.trino.orc.writer; import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.XxHash64; import io.trino.array.IntBigArray; -import io.trino.spi.block.Block; import io.trino.spi.block.VariableWidthBlock; import java.util.Arrays; @@ -86,7 +86,7 @@ public long getRetainedSizeInBytes() blockPositionByHash.sizeOf(); } - public Block getElementBlock() + public VariableWidthBlock getElementBlock() { boolean[] isNull = new boolean[entryCount]; isNull[NULL_POSITION] = true; @@ -103,7 +103,7 @@ public void clear() Arrays.fill(offsets, 0); } - public int putIfAbsent(Block block, int position) + public int putIfAbsent(VariableWidthBlock block, int position) { requireNonNull(block, "block must not be null"); @@ -131,11 +131,14 @@ public int getEntryCount() /** * Get slot position of the element at {@code position} of {@code block} */ - private long getHashPositionOfElement(Block block, int position) + private long getHashPositionOfElement(VariableWidthBlock block, int position) { checkArgument(!block.isNull(position), "position is null"); + Slice rawSlice = block.getRawSlice(); + int rawSliceOffset = block.getRawSliceOffset(position); int length = block.getSliceLength(position); - long hashPosition = getMaskedHash(block.hash(position, 0, length)); + + long hashPosition = getMaskedHash(XxHash64.hash(rawSlice, rawSliceOffset, length)); while (true) { int entryPosition = blockPositionByHash.get(hashPosition); if (entryPosition == EMPTY_SLOT) { @@ -144,7 +147,7 @@ private long getHashPositionOfElement(Block block, int position) } int entryOffset = offsets[entryPosition]; int entryLength = offsets[entryPosition + 1] - entryOffset; - if (entryLength == length && block.bytesEqual(position, 0, sliceOutput.getUnderlyingSlice(), entryOffset, entryLength)) { + if (rawSlice.equals(rawSliceOffset, length, sliceOutput.getUnderlyingSlice(), entryOffset, entryLength)) { // Already has this element return hashPosition; } @@ -153,14 +156,13 @@ private long getHashPositionOfElement(Block block, int position) } } - private int addNewElement(long hashPosition, Block block, int position) + private int addNewElement(long hashPosition, VariableWidthBlock block, int position) { checkArgument(!block.isNull(position), "position is null"); int newElementPositionInBlock = entryCount; - int length = block.getSliceLength(position); - block.writeSliceTo(position, 0, length, sliceOutput); + sliceOutput.writeBytes(block.getRawSlice(), block.getRawSliceOffset(position), block.getSliceLength(position)); entryCount++; offsets[entryCount] = sliceOutput.size(); diff --git a/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java b/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java index 045d2fee05fb..1766a2a45d1a 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/writer/SliceDictionaryColumnWriter.java @@ -38,6 +38,7 @@ import io.trino.orc.stream.StreamDataOutput; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrays; @@ -282,17 +283,19 @@ public void writeBlock(Block block) // record values values.ensureCapacity(rowGroupValueCount + block.getPositionCount()); - for (int position = 0; position < block.getPositionCount(); position++) { - int index = dictionary.putIfAbsent(block, position); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + for (int i = 0; i < block.getPositionCount(); i++) { + int position = block.getUnderlyingValuePosition(i); + int index = dictionary.putIfAbsent(valueBlock, position); values.set(rowGroupValueCount, index); rowGroupValueCount++; totalValueCount++; - if (!block.isNull(position)) { + if (!valueBlock.isNull(position)) { // todo min/max statistics only need to be updated if value was not already in the dictionary, but non-null count does - statisticsBuilder.addValue(type.getSlice(block, position)); + statisticsBuilder.addValue(type.getSlice(valueBlock, position)); - rawBytes += block.getSliceLength(position); + rawBytes += valueBlock.getSliceLength(position); totalNonNullValueCount++; } } @@ -349,7 +352,7 @@ private void bufferOutputData() checkState(closed); checkState(!directEncoded); - Block dictionaryElements = dictionary.getElementBlock(); + VariableWidthBlock dictionaryElements = dictionary.getElementBlock(); // write dictionary in sorted order int[] sortedDictionaryIndexes = getSortedDictionaryNullsLast(dictionaryElements); @@ -404,13 +407,14 @@ private void bufferOutputData() presentStream.close(); } - private static int[] getSortedDictionaryNullsLast(Block elementBlock) + private static int[] getSortedDictionaryNullsLast(VariableWidthBlock elementBlock) { int[] sortedPositions = new int[elementBlock.getPositionCount()]; for (int i = 0; i < sortedPositions.length; i++) { sortedPositions[i] = i; } + Slice rawSlice = elementBlock.getRawSlice(); IntArrays.quickSort(sortedPositions, 0, sortedPositions.length, (int left, int right) -> { boolean nullLeft = elementBlock.isNull(left); boolean nullRight = elementBlock.isNull(right); @@ -423,13 +427,11 @@ private static int[] getSortedDictionaryNullsLast(Block elementBlock) if (nullRight) { return -1; } - return elementBlock.compareTo( - left, - 0, + return rawSlice.compareTo( + elementBlock.getRawSliceOffset(left), elementBlock.getSliceLength(left), - elementBlock, - right, - 0, + rawSlice, + elementBlock.getRawSliceOffset(right), elementBlock.getSliceLength(right)); }); diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java index e486503fe60f..9ed19323cb46 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/TestDictionaryBuilder.java @@ -47,12 +47,5 @@ public TestHashCollisionBlock(int positionCount, Slice slice, int[] offsets, boo { super(positionCount, slice, offsets, Optional.of(valueIsNull)); } - - @Override - public long hash(int position, int offset, int length) - { - // return 0 to hash to the reserved null position which is zero - return 0; - } } } diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java index 141083a9e104..74d3ac2cbe01 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Field.java @@ -19,7 +19,6 @@ import io.airlift.slice.Slice; import io.trino.plugin.accumulo.Types; import io.trino.spi.TrinoException; -import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.SqlMap; import io.trino.spi.type.Type; @@ -27,8 +26,6 @@ import java.sql.Time; import java.sql.Timestamp; -import java.util.Arrays; -import java.util.Objects; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -45,7 +42,6 @@ import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; -import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.Float.intBitsToFloat; import static java.lang.Math.floorDiv; import static java.lang.Math.toIntExact; @@ -69,57 +65,6 @@ public Field(Object nativeValue, Type type, boolean indexed) this.indexed = indexed; } - public Field(Field field) - { - this.type = field.type; - this.indexed = false; - - if (Types.isArrayType(this.type) || Types.isMapType(this.type)) { - this.value = field.value; - return; - } - - if (type.equals(BIGINT)) { - this.value = field.getLong(); - } - else if (type.equals(BOOLEAN)) { - this.value = field.getBoolean(); - } - else if (type.equals(DATE)) { - this.value = field.getDate(); - } - else if (type.equals(DOUBLE)) { - this.value = field.getDouble(); - } - else if (type.equals(INTEGER)) { - this.value = field.getInt(); - } - else if (type.equals(REAL)) { - this.value = field.getFloat(); - } - else if (type.equals(SMALLINT)) { - this.value = field.getShort(); - } - else if (type.equals(TIME_MILLIS)) { - this.value = new Time(field.getTime().getTime()); - } - else if (type.equals(TIMESTAMP_MILLIS)) { - this.value = new Timestamp(field.getTimestamp().getTime()); - } - else if (type.equals(TINYINT)) { - this.value = field.getByte(); - } - else if (type.equals(VARBINARY)) { - this.value = Arrays.copyOf(field.getVarbinary(), field.getVarbinary().length); - } - else if (type.equals(VARCHAR)) { - this.value = field.getVarchar(); - } - else { - throw new TrinoException(NOT_SUPPORTED, "Unsupported type " + type); - } - } - public Type getType() { return type; @@ -210,59 +155,6 @@ public boolean isNull() return value == null; } - @Override - public int hashCode() - { - return Objects.hash(value, type, indexed); - } - - @Override - public boolean equals(Object obj) - { - boolean retval = true; - if (obj instanceof Field field) { - if (type.equals(field.getType())) { - if (this.isNull() && field.isNull()) { - retval = true; - } - else if (this.isNull() != field.isNull()) { - retval = false; - } - else if (type.equals(VARBINARY)) { - // special case for byte arrays - // aren't they so fancy - retval = Arrays.equals((byte[]) value, (byte[]) field.getObject()); - } - else if (type.equals(DATE) || type.equals(TIME_MILLIS) || type.equals(TIMESTAMP_MILLIS)) { - retval = value.toString().equals(field.getObject().toString()); - } - else { - if (value instanceof Block) { - retval = equals((Block) value, (Block) field.getObject()); - } - else { - retval = value.equals(field.getObject()); - } - } - } - } - return retval; - } - - private static boolean equals(Block block1, Block block2) - { - boolean retval = block1.getPositionCount() == block2.getPositionCount(); - for (int i = 0; i < block1.getPositionCount() && retval; ++i) { - if (block1 instanceof ArrayBlock && block2 instanceof ArrayBlock) { - retval = equals(block1.getObject(i, Block.class), block2.getObject(i, Block.class)); - } - else { - retval = block1.compareTo(i, 0, block1.getSliceLength(i), block2, i, 0, block2.getSliceLength(i)) == 0; - } - } - return retval; - } - @Override public String toString() { diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java index bf43b5b60484..3b5c674a5fc9 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/model/Row.java @@ -16,10 +16,7 @@ import io.trino.spi.type.Type; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; import static java.util.Objects.requireNonNull; @@ -29,12 +26,6 @@ public class Row public Row() {} - public Row(Row row) - { - requireNonNull(row, "row is null"); - fields.addAll(row.fields.stream().map(Field::new).collect(Collectors.toList())); - } - public Row addField(Field field) { requireNonNull(field, "field is null"); @@ -54,33 +45,11 @@ public Field getField(int i) return fields.get(i); } - /** - * Gets a list of all internal fields. Any changes to this list will affect this row. - * - * @return List of fields - */ - public List getFields() - { - return fields; - } - public int length() { return fields.size(); } - @Override - public int hashCode() - { - return Arrays.hashCode(fields.toArray()); - } - - @Override - public boolean equals(Object obj) - { - return obj instanceof Row && Objects.equals(this.fields, ((Row) obj).getFields()); - } - @Override public String toString() { diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java index d1ba43f12262..e83836255815 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/serializers/AccumuloRowSerializer.java @@ -21,6 +21,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.block.SqlMap; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeUtils; @@ -604,12 +605,12 @@ else if (type instanceof MapType mapType) { */ static Object readObject(Type type, Block block, int position) { - if (Types.isArrayType(type)) { - Type elementType = Types.getElementType(type); - return getArrayFromBlock(elementType, block.getObject(position, Block.class)); + if (type instanceof ArrayType arrayType) { + Type elementType = arrayType.getElementType(); + return getArrayFromBlock(elementType, arrayType.getObject(block, position)); } - if (Types.isMapType(type)) { - return getMapFromSqlMap(type, block.getObject(position, SqlMap.class)); + if (type instanceof MapType mapType) { + return getMapFromSqlMap(type, mapType.getObject(block, position)); } if (type.getJavaType() == Slice.class) { Slice slice = (Slice) TypeUtils.readNativeValue(type, block, position); diff --git a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java index 825b9c5b6d13..9edc7755aaec 100644 --- a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java +++ b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestField.java @@ -68,9 +68,6 @@ public void testArray() assertEquals(f1.getArray(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -86,9 +83,6 @@ public void testBoolean() assertEquals(f1.getBoolean().booleanValue(), false); assertEquals(f1.getObject(), false); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -100,9 +94,6 @@ public void testDate() assertEquals(f1.getDate(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -114,9 +105,6 @@ public void testDouble() assertEquals(f1.getDouble(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -128,9 +116,6 @@ public void testFloat() assertEquals(f1.getFloat(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -142,9 +127,6 @@ public void testInt() assertEquals(f1.getInt(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -156,9 +138,6 @@ public void testLong() assertEquals(f1.getLong(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -183,9 +162,6 @@ public void testSmallInt() assertEquals(f1.getShort(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -197,9 +173,6 @@ public void testTime() assertEquals(f1.getTime(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -211,9 +184,6 @@ public void testTimestamp() assertEquals(f1.getTimestamp(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -225,9 +195,6 @@ public void testTinyInt() assertEquals(f1.getByte(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -239,9 +206,6 @@ public void testVarbinary() assertEquals(f1.getVarbinary(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } @Test @@ -253,8 +217,5 @@ public void testVarchar() assertEquals(f1.getVarchar(), expected); assertEquals(f1.getObject(), expected); assertEquals(f1.getType(), type); - - Field f2 = new Field(f1); - assertEquals(f2, f1); } } diff --git a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java index 87787049275e..67de0e6a892d 100644 --- a/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java +++ b/plugin/trino-accumulo/src/test/java/io/trino/plugin/accumulo/model/TestRow.java @@ -63,9 +63,6 @@ public void testRow() r1.addField(null, VARCHAR); assertEquals(r1.length(), 14); - - Row r2 = new Row(r1); - assertEquals(r2, r1); } @Test diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java index 8dd65dae1d0e..3bdcb3f30627 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTypeUtils.java @@ -111,7 +111,7 @@ public static Object readNativeValue(Type type, Block block, int position) return timestampToStringConverter(timestamp); } if (type instanceof ArrayType arrayType) { - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); ImmutableList.Builder list = ImmutableList.builderWithExpectedSize(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { Object element = readNativeValue(arrayType.getElementType(), arrayBlock, i); @@ -123,7 +123,7 @@ public static Object readNativeValue(Type type, Block block, int position) return list.build(); } if (type instanceof RowType rowType) { - SqlRow sqlRow = block.getObject(position, SqlRow.class); + SqlRow sqlRow = rowType.getObject(block, position); List fieldTypes = rowType.getTypeParameters(); if (fieldTypes.size() != sqlRow.getFieldCount()) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java index d34321d5dac0..9acaa296792f 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java @@ -17,6 +17,7 @@ import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.math.LongMath; import io.airlift.log.Logger; import io.trino.filesystem.TrinoInputFile; @@ -41,8 +42,16 @@ import io.trino.plugin.hive.parquet.ParquetPageSourceFactory; import io.trino.spi.Page; import io.trino.spi.TrinoException; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.MapBlock; +import io.trino.spi.block.RowBlock; import io.trino.spi.block.SqlRow; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.Domain; @@ -71,7 +80,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR; import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA; @@ -96,7 +104,6 @@ import static java.lang.Math.floorDiv; import static java.lang.String.format; import static java.math.RoundingMode.UNNECESSARY; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; public class CheckpointEntryIterator @@ -161,7 +168,7 @@ public CheckpointEntryIterator( this.stringList = (ArrayType) typeManager.getType(TypeSignature.arrayType(VARCHAR.getTypeSignature())); this.stringMap = (MapType) typeManager.getType(TypeSignature.mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())); this.checkpointRowStatisticsWritingEnabled = checkpointRowStatisticsWritingEnabled; - checkArgument(fields.size() > 0, "fields is empty"); + checkArgument(!fields.isEmpty(), "fields is empty"); Map extractors = ImmutableMap.builder() .put(TRANSACTION, this::buildTxnEntry) .put(ADD, this::buildAddEntry) @@ -279,41 +286,40 @@ private DeltaLakeTransactionLogEntry buildCommitInfoEntry(ConnectorSession sessi int jobFields = 5; int notebookFields = 1; SqlRow commitInfoRow = block.getObject(pagePosition, SqlRow.class); - int commitInfoRawIndex = commitInfoRow.getRawIndex(); log.debug("Block %s has %s fields", block, commitInfoRow.getFieldCount()); if (commitInfoRow.getFieldCount() != commitInfoFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, commitInfoFields, commitInfoRow.getFieldCount())); } - SqlRow jobRow = commitInfoRow.getRawFieldBlock(9).getObject(commitInfoRawIndex, SqlRow.class); + SqlRow jobRow = getRowField(commitInfoRow, 9); if (jobRow.getFieldCount() != jobFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", jobRow, jobFields, jobRow.getFieldCount())); } - SqlRow notebookRow = commitInfoRow.getRawFieldBlock(7).getObject(commitInfoRawIndex, SqlRow.class); + SqlRow notebookRow = getRowField(commitInfoRow, 7); if (notebookRow.getFieldCount() != notebookFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", notebookRow, notebookFields, notebookRow.getFieldCount())); } CommitInfoEntry result = new CommitInfoEntry( - getLong(commitInfoRow.getRawFieldBlock(0), commitInfoRawIndex), - getLong(commitInfoRow.getRawFieldBlock(1), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(2), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(3), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(4), commitInfoRawIndex), - getMap(commitInfoRow.getRawFieldBlock(5), commitInfoRawIndex), + getLongField(commitInfoRow, 0), + getLongField(commitInfoRow, 1), + getStringField(commitInfoRow, 2), + getStringField(commitInfoRow, 3), + getStringField(commitInfoRow, 4), + getMapField(commitInfoRow, 5), new CommitInfoEntry.Job( - getString(jobRow.getRawFieldBlock(0), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(1), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(2), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(3), jobRow.getRawIndex()), - getString(jobRow.getRawFieldBlock(4), jobRow.getRawIndex())), + getStringField(jobRow, 0), + getStringField(jobRow, 1), + getStringField(jobRow, 2), + getStringField(jobRow, 3), + getStringField(jobRow, 4)), new CommitInfoEntry.Notebook( - getString(notebookRow.getRawFieldBlock(0), notebookRow.getRawIndex())), - getString(commitInfoRow.getRawFieldBlock(8), commitInfoRawIndex), - getLong(commitInfoRow.getRawFieldBlock(9), commitInfoRawIndex), - getString(commitInfoRow.getRawFieldBlock(10), commitInfoRawIndex), - Optional.of(getByte(commitInfoRow.getRawFieldBlock(11), commitInfoRawIndex) != 0)); + getStringField(notebookRow, 0)), + getStringField(commitInfoRow, 8), + getLongField(commitInfoRow, 9), + getStringField(commitInfoRow, 10), + Optional.of(getBooleanField(commitInfoRow, 11))); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.commitInfoEntry(result); } @@ -333,15 +339,14 @@ private DeltaLakeTransactionLogEntry buildProtocolEntry(ConnectorSession session throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have between %d and %d children, but found %s", block, minProtocolFields, maxProtocolFields, fieldCount)); } - int rawIndex = protocolEntryRow.getRawIndex(); - Block readerFeaturesField = protocolEntryRow.getRawFieldBlock(2); + Optional> readerFeatures = getOptionalSetField(protocolEntryRow, 2); // The last entry should be writer feature when protocol entry size is 3 https://github.com/delta-io/delta/blob/master/PROTOCOL.md#disabled-features - Block writerFeaturesField = fieldCount != 4 ? readerFeaturesField : protocolEntryRow.getRawFieldBlock(3); + Optional> writerFeatures = fieldCount != 4 ? readerFeatures : getOptionalSetField(protocolEntryRow, 3); ProtocolEntry result = new ProtocolEntry( - getInt(protocolEntryRow.getRawFieldBlock(0), rawIndex), - getInt(protocolEntryRow.getRawFieldBlock(1), rawIndex), - readerFeaturesField.isNull(rawIndex) ? Optional.empty() : Optional.of(getList(readerFeaturesField, rawIndex).stream().collect(toImmutableSet())), - writerFeaturesField.isNull(rawIndex) ? Optional.empty() : Optional.of(getList(writerFeaturesField, rawIndex).stream().collect(toImmutableSet()))); + getIntField(protocolEntryRow, 0), + getIntField(protocolEntryRow, 1), + readerFeatures, + writerFeatures); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.protocolEntry(result); } @@ -355,28 +360,27 @@ private DeltaLakeTransactionLogEntry buildMetadataEntry(ConnectorSession session int metadataFields = 8; int formatFields = 2; SqlRow metadataEntryRow = block.getObject(pagePosition, SqlRow.class); - int rawIndex = metadataEntryRow.getRawIndex(); log.debug("Block %s has %s fields", block, metadataEntryRow.getFieldCount()); if (metadataEntryRow.getFieldCount() != metadataFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, metadataFields, metadataEntryRow.getFieldCount())); } - SqlRow formatRow = metadataEntryRow.getRawFieldBlock(3).getObject(rawIndex, SqlRow.class); + SqlRow formatRow = getRowField(metadataEntryRow, 3); if (formatRow.getFieldCount() != formatFields) { throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", formatRow, formatFields, formatRow.getFieldCount())); } MetadataEntry result = new MetadataEntry( - getString(metadataEntryRow.getRawFieldBlock(0), rawIndex), - getString(metadataEntryRow.getRawFieldBlock(1), rawIndex), - getString(metadataEntryRow.getRawFieldBlock(2), rawIndex), + getStringField(metadataEntryRow, 0), + getStringField(metadataEntryRow, 1), + getStringField(metadataEntryRow, 2), new MetadataEntry.Format( - getString(formatRow.getRawFieldBlock(0), formatRow.getRawIndex()), - getMap(formatRow.getRawFieldBlock(1), formatRow.getRawIndex())), - getString(metadataEntryRow.getRawFieldBlock(4), rawIndex), - getList(metadataEntryRow.getRawFieldBlock(5), rawIndex), - getMap(metadataEntryRow.getRawFieldBlock(6), rawIndex), - getLong(metadataEntryRow.getRawFieldBlock(7), rawIndex)); + getStringField(formatRow, 0), + getMapField(formatRow, 1)), + getStringField(metadataEntryRow, 4), + getListField(metadataEntryRow, 5), + getMapField(metadataEntryRow, 6), + getLongField(metadataEntryRow, 7)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.metadataEntry(result); } @@ -394,11 +398,10 @@ private DeltaLakeTransactionLogEntry buildRemoveEntry(ConnectorSession session, throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, removeFields, removeEntryRow.getFieldCount())); } - int rawIndex = removeEntryRow.getRawIndex(); RemoveFileEntry result = new RemoveFileEntry( - getString(removeEntryRow.getRawFieldBlock(0), rawIndex), - getLong(removeEntryRow.getRawFieldBlock(1), rawIndex), - getByte(removeEntryRow.getRawFieldBlock(2), rawIndex) != 0); + getStringField(removeEntryRow, 0), + getLongField(removeEntryRow, 1), + getBooleanField(removeEntryRow, 2)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.removeFileEntry(result); } @@ -412,92 +415,71 @@ private DeltaLakeTransactionLogEntry buildAddEntry(ConnectorSession session, Blo boolean deletionVectorsEnabled = isDeletionVectorEnabled(metadataEntry, protocolEntry); SqlRow addEntryRow = block.getObject(pagePosition, SqlRow.class); log.debug("Block %s has %s fields", block, addEntryRow.getFieldCount()); - int rawIndex = addEntryRow.getRawIndex(); - String path = getString(addEntryRow.getRawFieldBlock(0), rawIndex); - Map partitionValues = getMap(addEntryRow.getRawFieldBlock(1), rawIndex); - long size = getLong(addEntryRow.getRawFieldBlock(2), rawIndex); - long modificationTime = getLong(addEntryRow.getRawFieldBlock(3), rawIndex); - boolean dataChange = getByte(addEntryRow.getRawFieldBlock(4), rawIndex) != 0; + String path = getStringField(addEntryRow, 0); + Map partitionValues = getMapField(addEntryRow, 1); + long size = getLongField(addEntryRow, 2); + long modificationTime = getLongField(addEntryRow, 3); + boolean dataChange = getBooleanField(addEntryRow, 4); + Optional deletionVector = Optional.empty(); - int position = 5; + int statsFieldIndex; if (deletionVectorsEnabled) { - if (!addEntryRow.getRawFieldBlock(5).isNull(rawIndex)) { - deletionVector = Optional.of(parseDeletionVectorFromParquet(addEntryRow.getRawFieldBlock(5).getObject(rawIndex, Block.class))); - } - position = 6; - } - Map tags = getMap(addEntryRow.getRawFieldBlock(position + 2), rawIndex); - - AddFileEntry result; - if (!addEntryRow.getRawFieldBlock(position + 1).isNull(rawIndex)) { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.empty(), - Optional.of(parseStatisticsFromParquet(addEntryRow.getRawFieldBlock(position + 1).getObject(rawIndex, SqlRow.class))), - tags, - deletionVector); - } - else if (!addEntryRow.getRawFieldBlock(position).isNull(rawIndex)) { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.of(getString(addEntryRow.getRawFieldBlock(position), rawIndex)), - Optional.empty(), - tags, - deletionVector); + deletionVector = Optional.ofNullable(getRowField(addEntryRow, 5)).map(CheckpointEntryIterator::parseDeletionVectorFromParquet); + statsFieldIndex = 6; } else { - result = new AddFileEntry( - path, - partitionValues, - size, - modificationTime, - dataChange, - Optional.empty(), - Optional.empty(), - tags, - deletionVector); + statsFieldIndex = 5; + } + + Optional parsedStats = Optional.ofNullable(getRowField(addEntryRow, statsFieldIndex + 1)).map(this::parseStatisticsFromParquet); + Optional stats = Optional.empty(); + if (parsedStats.isEmpty()) { + stats = Optional.ofNullable(getStringField(addEntryRow, statsFieldIndex)); } + Map tags = getMapField(addEntryRow, statsFieldIndex + 2); + AddFileEntry result = new AddFileEntry( + path, + partitionValues, + size, + modificationTime, + dataChange, + stats, + parsedStats, + tags, + deletionVector); + log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.addFileEntry(result); } - private DeletionVectorEntry parseDeletionVectorFromParquet(Block block) + private static DeletionVectorEntry parseDeletionVectorFromParquet(SqlRow row) { - checkArgument(block.getPositionCount() == 5, "Deletion vector entry must have 5 fields"); + checkArgument(row.getFieldCount() == 5, "Deletion vector entry must have 5 fields"); - String storageType = getString(block, 0); - String pathOrInlineDv = getString(block, 1); - OptionalInt offset = block.isNull(2) ? OptionalInt.empty() : OptionalInt.of(getInt(block, 2)); - int sizeInBytes = getInt(block, 3); - long cardinality = getLong(block, 4); + String storageType = getStringField(row, 0); + String pathOrInlineDv = getStringField(row, 1); + OptionalInt offset = getOptionalIntField(row, 2); + int sizeInBytes = getIntField(row, 3); + long cardinality = getLongField(row, 4); return new DeletionVectorEntry(storageType, pathOrInlineDv, offset, sizeInBytes, cardinality); } private DeltaLakeParquetFileStatistics parseStatisticsFromParquet(SqlRow statsRow) { - int rawIndex = statsRow.getRawIndex(); - long numRecords = getLong(statsRow.getRawFieldBlock(0), rawIndex); + long numRecords = getLongField(statsRow, 0); Optional> minValues = Optional.empty(); Optional> maxValues = Optional.empty(); Optional> nullCount; if (!columnsWithMinMaxStats.isEmpty()) { - minValues = Optional.of(readMinMax(statsRow.getRawFieldBlock(1), rawIndex, columnsWithMinMaxStats)); - maxValues = Optional.of(readMinMax(statsRow.getRawFieldBlock(2), rawIndex, columnsWithMinMaxStats)); - nullCount = Optional.of(readNullCount(statsRow.getRawFieldBlock(3), rawIndex, schema)); + minValues = Optional.of(parseMinMax(getRowField(statsRow, 1), columnsWithMinMaxStats)); + maxValues = Optional.of(parseMinMax(getRowField(statsRow, 2), columnsWithMinMaxStats)); + nullCount = Optional.of(parseNullCount(getRowField(statsRow, 3), schema)); } else { - nullCount = Optional.of(readNullCount(statsRow.getRawFieldBlock(1), rawIndex, schema)); + nullCount = Optional.of(parseNullCount(getRowField(statsRow, 1), schema)); } return new DeltaLakeParquetFileStatistics( @@ -507,71 +489,69 @@ private DeltaLakeParquetFileStatistics parseStatisticsFromParquet(SqlRow statsRo nullCount); } - private Map readMinMax(Block block, int blockPosition, List eligibleColumns) + private ImmutableMap parseMinMax(@Nullable SqlRow row, List eligibleColumns) { - if (block.isNull(blockPosition)) { + if (row == null) { // Statistics were not collected return ImmutableMap.of(); } - SqlRow row = block.getObject(blockPosition, SqlRow.class); ImmutableMap.Builder values = ImmutableMap.builder(); - int rawIndex = row.getRawIndex(); for (int i = 0; i < eligibleColumns.size(); i++) { DeltaLakeColumnMetadata metadata = eligibleColumns.get(i); String name = metadata.getPhysicalName(); Type type = metadata.getPhysicalColumnType(); - Block fieldBlock = row.getRawFieldBlock(i); - if (fieldBlock.isNull(rawIndex)) { + ValueBlock fieldBlock = row.getUnderlyingFieldBlock(i); + int fieldIndex = row.getUnderlyingFieldPosition(i); + if (fieldBlock.isNull(fieldIndex)) { continue; } if (type instanceof RowType rowType) { if (checkpointRowStatisticsWritingEnabled) { // RowType column statistics are not used for query planning, but need to be copied when writing out new Checkpoint files. - values.put(name, rowType.getObject(fieldBlock, rawIndex)); + values.put(name, rowType.getObject(fieldBlock, fieldIndex)); } continue; } if (type instanceof TimestampWithTimeZoneType) { - long epochMillis = LongMath.divide((long) readNativeValue(TIMESTAMP_MILLIS, fieldBlock, rawIndex), MICROSECONDS_PER_MILLISECOND, UNNECESSARY); + long epochMillis = LongMath.divide((long) readNativeValue(TIMESTAMP_MILLIS, fieldBlock, fieldIndex), MICROSECONDS_PER_MILLISECOND, UNNECESSARY); if (floorDiv(epochMillis, MILLISECONDS_PER_DAY) >= START_OF_MODERN_ERA_EPOCH_DAY) { values.put(name, packDateTimeWithZone(epochMillis, UTC_KEY)); } continue; } - values.put(name, readNativeValue(type, fieldBlock, rawIndex)); + values.put(name, readNativeValue(type, fieldBlock, fieldIndex)); } return values.buildOrThrow(); } - private Map readNullCount(Block block, int blockPosition, List columns) + private Map parseNullCount(SqlRow row, List columns) { - if (block.isNull(blockPosition)) { + if (row == null) { // Statistics were not collected return ImmutableMap.of(); } - SqlRow row = block.getObject(blockPosition, SqlRow.class); - int rawIndex = row.getRawIndex(); ImmutableMap.Builder values = ImmutableMap.builder(); for (int i = 0; i < columns.size(); i++) { DeltaLakeColumnMetadata metadata = columns.get(i); - Block fieldBlock = row.getRawFieldBlock(i); - if (fieldBlock.isNull(rawIndex)) { + ValueBlock fieldBlock = row.getUnderlyingFieldBlock(i); + int fieldIndex = row.getUnderlyingFieldPosition(i); + if (fieldBlock.isNull(fieldIndex)) { continue; } if (metadata.getType() instanceof RowType) { if (checkpointRowStatisticsWritingEnabled) { // RowType column statistics are not used for query planning, but need to be copied when writing out new Checkpoint files. - values.put(metadata.getPhysicalName(), fieldBlock.getObject(rawIndex, SqlRow.class)); + values.put(metadata.getPhysicalName(), fieldBlock.getObject(fieldIndex, SqlRow.class)); } continue; } - values.put(metadata.getPhysicalName(), getLong(fieldBlock, rawIndex)); + values.put(metadata.getPhysicalName(), getLongField(row, i)); } return values.buildOrThrow(); } @@ -589,52 +569,88 @@ private DeltaLakeTransactionLogEntry buildTxnEntry(ConnectorSession session, Blo throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA, format("Expected block %s to have %d children, but found %s", block, txnFields, txnEntryRow.getFieldCount())); } - int rawIndex = txnEntryRow.getRawIndex(); TransactionEntry result = new TransactionEntry( - getString(txnEntryRow.getRawFieldBlock(0), rawIndex), - getLong(txnEntryRow.getRawFieldBlock(1), rawIndex), - getLong(txnEntryRow.getRawFieldBlock(2), rawIndex)); + getStringField(txnEntryRow, 0), + getLongField(txnEntryRow, 1), + getLongField(txnEntryRow, 2)); log.debug("Result: %s", result); return DeltaLakeTransactionLogEntry.transactionEntry(result); } @Nullable - private String getString(Block block, int position) + private static SqlRow getRowField(SqlRow row, int field) { - if (block.isNull(position)) { + RowBlock valueBlock = (RowBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { return null; } - return block.getSlice(position, 0, block.getSliceLength(position)).toString(UTF_8); + return valueBlock.getRow(index); } - private long getLong(Block block, int position) + @Nullable + private static String getStringField(SqlRow row, int field) { - checkArgument(!block.isNull(position)); - return block.getLong(position, 0); + VariableWidthBlock valueBlock = (VariableWidthBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return null; + } + return valueBlock.getSlice(index).toStringUtf8(); } - private int getInt(Block block, int position) + private static long getLongField(SqlRow row, int field) { - checkArgument(!block.isNull(position)); - return block.getInt(position, 0); + LongArrayBlock valueBlock = (LongArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getLong(row.getUnderlyingFieldPosition(field)); } - private byte getByte(Block block, int position) + private static int getIntField(SqlRow row, int field) { - checkArgument(!block.isNull(position)); - return block.getByte(position, 0); + IntArrayBlock valueBlock = (IntArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getInt(row.getUnderlyingFieldPosition(field)); + } + + private static OptionalInt getOptionalIntField(SqlRow row, int field) + { + IntArrayBlock valueBlock = (IntArrayBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return OptionalInt.empty(); + } + return OptionalInt.of(valueBlock.getInt(index)); + } + + private static boolean getBooleanField(SqlRow row, int field) + { + ByteArrayBlock valueBlock = (ByteArrayBlock) row.getUnderlyingFieldBlock(field); + return valueBlock.getByte(row.getUnderlyingFieldPosition(field)) != 0; } @SuppressWarnings("unchecked") - private Map getMap(Block block, int position) + private Map getMapField(SqlRow row, int field) { - return (Map) stringMap.getObjectValue(session, block, position); + MapBlock valueBlock = (MapBlock) row.getUnderlyingFieldBlock(field); + return (Map) stringMap.getObjectValue(session, valueBlock, row.getUnderlyingFieldPosition(field)); } @SuppressWarnings("unchecked") - private List getList(Block block, int position) + private List getListField(SqlRow row, int field) { - return (List) stringList.getObjectValue(session, block, position); + ArrayBlock valueBlock = (ArrayBlock) row.getUnderlyingFieldBlock(field); + return (List) stringList.getObjectValue(session, valueBlock, row.getUnderlyingFieldPosition(field)); + } + + @SuppressWarnings("unchecked") + private Optional> getOptionalSetField(SqlRow row, int field) + { + ArrayBlock valueBlock = (ArrayBlock) row.getUnderlyingFieldBlock(field); + int index = row.getUnderlyingFieldPosition(field); + if (valueBlock.isNull(index)) { + return Optional.empty(); + } + List list = (List) stringList.getObjectValue(session, valueBlock, index); + return Optional.of(ImmutableSet.copyOf(list)); } @Override diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java index 659ce9d6585f..8f9d5e8b2453 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/BingTileType.java @@ -24,7 +24,7 @@ public class BingTileType public static final BingTileType BING_TILE = new BingTileType(); public static final String NAME = "BingTile"; - public BingTileType() + private BingTileType() { super(new TypeSignature(NAME)); } @@ -42,6 +42,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - return BingTile.decode(block.getLong(position, 0)); + return BingTile.decode(getLong(block, position)); } } diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java index 510c28dc2dae..9904fbd0006f 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/GeometryType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -42,7 +43,9 @@ protected GeometryType(TypeSignature signature) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -71,7 +74,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); - return deserialize(slice).asText(); + return deserialize(getSlice(block, position)).asText(); } } diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java index e9daa2ebf90a..f66e3afd1863 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/KdbTreeType.java @@ -19,6 +19,7 @@ import io.trino.geospatial.KdbTreeUtils; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; @@ -54,7 +55,7 @@ private KdbTreeType() { // The KDB tree type should be KdbTree but can not be since KdbTree is in // both the plugin class loader and the system class loader. This was done - // so the plan optimizer can process geo spatial joins. + // so the plan optimizer can process geospatial joins. super(new TypeSignature(NAME), Object.class); } @@ -83,9 +84,10 @@ public Object getObject(Block block, int position) if (block.isNull(position)) { return null; } - Slice bytes = block.getSlice(position, 0, block.getSliceLength(position)); - KdbTree kdbTree = KdbTreeUtils.fromJson(bytes.toStringUtf8()); - return kdbTree; + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + String json = valueBlock.getSlice(valuePosition).toStringUtf8(); + return KdbTreeUtils.fromJson(json); } @Override @@ -149,14 +151,16 @@ private static void writeFlat( @ScalarOperator(READ_VALUE) private static void writeBlockToFlat( - @BlockPosition Block block, + @BlockPosition VariableWidthBlock block, @BlockIndex int position, byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset) { - Slice bytes = block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + Slice bytes = valueBlock.getSlice(valuePosition); bytes.getBytes(0, variableSizeSlice, variableSizeOffset, bytes.length()); INT_HANDLE.set(fixedSizeSlice, fixedSizeOffset, bytes.length()); diff --git a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java index 3dfcff0c3edc..af01848db877 100644 --- a/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java +++ b/plugin/trino-geospatial/src/main/java/io/trino/plugin/geospatial/SphericalGeographyType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -37,7 +38,9 @@ private SphericalGeographyType() @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override @@ -58,7 +61,6 @@ public Object getObjectValue(ConnectorSession session, Block block, int position if (block.isNull(position)) { return null; } - Slice slice = block.getSlice(position, 0, block.getSliceLength(position)); - return deserialize(slice).asText(); + return deserialize(getSlice(block, position)).asText(); } } diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java index f800c5f228c7..a233e1f26922 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestKdbTreeType.java @@ -16,8 +16,8 @@ import io.trino.geospatial.KdbTree; import io.trino.geospatial.KdbTree.Node; import io.trino.geospatial.Rectangle; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.type.AbstractTestType; import org.junit.jupiter.api.Test; @@ -36,7 +36,7 @@ protected TestKdbTreeType() super(KDB_TREE, KdbTree.class, createTestBlock()); } - private static Block createTestBlock() + private static ValueBlock createTestBlock() { BlockBuilder blockBuilder = KDB_TREE.createBlockBuilder(null, 1); KdbTree kdbTree = new KdbTree( @@ -46,7 +46,7 @@ private static Block createTestBlock() Optional.empty(), Optional.empty())); KDB_TREE.writeObject(blockBuilder, kdbTree); - return blockBuilder.build(); + return blockBuilder.buildValueBlock(); } @Override diff --git a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java index 8c6b70b0afd5..8207cd03895e 100644 --- a/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java +++ b/plugin/trino-geospatial/src/test/java/io/trino/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java @@ -30,6 +30,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.testing.LocalQueryRunner; import org.junit.jupiter.api.Test; @@ -71,7 +72,9 @@ public void test(int partitionCount) List geometries = makeGeometries(); Block geometryBlock = makeGeometryBlock(geometries); - Block partitionCountBlock = BlockAssertions.createRepeatedValuesBlock(partitionCount, geometries.size()); + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(null, 1); + INTEGER.writeInt(blockBuilder, partitionCount); + Block partitionCountBlock = RunLengthEncodedBlock.create(blockBuilder.build(), geometries.size()); Rectangle expectedExtent = new Rectangle(-10, -10, Math.nextUp(10.0), Math.nextUp(10.0)); String expectedValue = getSpatialPartitioning(expectedExtent, geometries, partitionCount); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java index a63435535669..301a43da59c2 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java @@ -228,8 +228,12 @@ public PartitionUpdateAndMergeResults getPartitionUpdateAndMergeResults(Partitio private Page buildDeletePage(Block rowIds, long writeId) { ColumnarRow columnarRow = toColumnarRow(rowIds); - checkArgument(!columnarRow.mayHaveNull(), "The rowIdsRowBlock may not have null rows"); int positionCount = rowIds.getPositionCount(); + if (columnarRow.mayHaveNull()) { + for (int position = 0; position < positionCount; position++) { + checkArgument(!columnarRow.isNull(position), "The rowIdsRowBlock may not have null rows"); + } + } // We've verified that the rowIds block has no null rows, so it's okay to get the field blocks Block[] blockArray = { RunLengthEncodedBlock.create(DELETE_OPERATION_BLOCK, positionCount), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java index 165fc86eb2a5..97585f0e3f09 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CoercionUtils.java @@ -128,7 +128,7 @@ public static Type createTypeFromCoercer(TypeManager typeManager, HiveType fromH return Optional.of(new IntegerNumberUpscaleCoercer<>(fromType, toType)); } if (fromHiveType.equals(HIVE_INT) && toHiveType.equals(HIVE_LONG)) { - return Optional.of(new IntegerNumberUpscaleCoercer<>(fromType, toType)); + return Optional.of(new IntegerToBigintCoercer()); } if (fromHiveType.equals(HIVE_FLOAT) && toHiveType.equals(HIVE_DOUBLE)) { return Optional.of(new FloatToDoubleCoercer()); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java index df8105941069..5a9ce09968e7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/FloatToDoubleCoercer.java @@ -16,6 +16,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.type.DoubleType; import io.trino.spi.type.RealType; @@ -30,6 +31,16 @@ public FloatToDoubleCoercer() super(REAL, DOUBLE); } + @Override + public Block apply(Block block) + { + // data may have already been coerced by the Avro reader + if (block instanceof LongArrayBlock) { + return block; + } + return super.apply(block); + } + @Override protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java new file mode 100644 index 000000000000..3cf4706b1855 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/IntegerToBigintCoercer.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.plugin.hive.coercions; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.IntegerType; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; + +public class IntegerToBigintCoercer + extends TypeCoercer +{ + public IntegerToBigintCoercer() + { + super(INTEGER, BIGINT); + } + + @Override + public Block apply(Block block) + { + // data may have already been coerced by the Avro reader + if (block instanceof LongArrayBlock) { + return block; + } + return super.apply(block); + } + + @Override + protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) + { + BIGINT.writeLong(blockBuilder, INTEGER.getInt(block, position)); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java index 00dad7292dc1..19610bf24465 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java @@ -45,6 +45,7 @@ import static io.trino.spi.block.RowBlock.fromFieldBlocks; import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -220,11 +221,12 @@ private static Block createLongArrayBlock(List data) private static void verifyBlock(Block actualBlock, Type outputType, Block input, Type inputType, List dereferences) { - Block expectedOutputBlock = createProjectedColumnBlock(input, outputType, inputType, dereferences); + assertThat(inputType).isInstanceOf(RowType.class); + Block expectedOutputBlock = createProjectedColumnBlock(input, outputType, (RowType) inputType, dereferences); assertBlockEquals(outputType, actualBlock, expectedOutputBlock); } - private static Block createProjectedColumnBlock(Block data, Type finalType, Type blockType, List dereferences) + private static Block createProjectedColumnBlock(Block data, Type finalType, RowType blockType, List dereferences) { if (dereferences.size() == 0) { return data; @@ -233,14 +235,14 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, Type BlockBuilder builder = finalType.createBlockBuilder(null, data.getPositionCount()); for (int i = 0; i < data.getPositionCount(); i++) { - Type sourceType = blockType; + RowType sourceType = blockType; SqlRow currentData = null; boolean isNull = data.isNull(i); if (!isNull) { - // Get SingleRowBlock corresponding to element at position i - currentData = data.getObject(i, SqlRow.class); + // Get SqlRow corresponding to element at position i + currentData = sourceType.getObject(data, i); } // Apply all dereferences except for the last one, because the type can be different @@ -253,14 +255,14 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, Type int fieldIndex = dereferences.get(j); Block fieldBlock = currentData.getRawFieldBlock(fieldIndex); - RowType rowType = (RowType) sourceType; + RowType rowType = sourceType; int rawIndex = currentData.getRawIndex(); if (fieldBlock.isNull(rawIndex)) { currentData = null; } else { - sourceType = rowType.getFields().get(fieldIndex).getType(); - currentData = fieldBlock.getObject(rawIndex, SqlRow.class); + sourceType = (RowType) rowType.getFields().get(fieldIndex).getType(); + currentData = sourceType.getObject(fieldBlock, rawIndex); } isNull = isNull || (currentData == null); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java index 64b4646b608e..d47f123f9e1d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/util/TestStatistics.java @@ -24,7 +24,6 @@ import io.trino.plugin.hive.metastore.HiveColumnStatistics; import io.trino.plugin.hive.metastore.IntegerStatistics; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; import io.trino.spi.statistics.ComputedStatistics; import io.trino.spi.statistics.TableStatisticType; import io.trino.spi.type.Type; @@ -36,7 +35,6 @@ import java.util.Optional; import java.util.OptionalDouble; import java.util.OptionalLong; -import java.util.function.Function; import static io.trino.plugin.hive.HiveBasicStatistics.createEmptyStatistics; import static io.trino.plugin.hive.HiveBasicStatistics.createZeroStatistics; @@ -57,6 +55,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.Float.floatToIntBits; import static org.assertj.core.api.Assertions.assertThat; @@ -318,20 +317,13 @@ public void testMergeHiveColumnStatisticsMap() @Test public void testFromComputedStatistics() { - Function singleIntegerValueBlock = value -> - { - BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, 1); - BIGINT.writeLong(blockBuilder, value); - return blockBuilder.build(); - }; - ComputedStatistics statistics = ComputedStatistics.builder(ImmutableList.of(), ImmutableList.of()) - .addTableStatistic(TableStatisticType.ROW_COUNT, singleIntegerValueBlock.apply(5)) - .addColumnStatistic(MIN_VALUE.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(1)) - .addColumnStatistic(MAX_VALUE.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_DISTINCT_VALUES.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("a_column"), singleIntegerValueBlock.apply(5)) - .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("b_column"), singleIntegerValueBlock.apply(4)) + .addTableStatistic(TableStatisticType.ROW_COUNT, writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(MIN_VALUE.createColumnStatisticMetadata("a_column"), writeNativeValue(INTEGER, 1L)) + .addColumnStatistic(MAX_VALUE.createColumnStatisticMetadata("a_column"), writeNativeValue(INTEGER, 5L)) + .addColumnStatistic(NUMBER_OF_DISTINCT_VALUES.createColumnStatisticMetadata("a_column"), writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("a_column"), writeNativeValue(BIGINT, 5L)) + .addColumnStatistic(NUMBER_OF_NON_NULL_VALUES.createColumnStatisticMetadata("b_column"), writeNativeValue(BIGINT, 4L)) .build(); Map columnTypes = ImmutableMap.of("a_column", INTEGER, "b_column", VARCHAR); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java index 367b8f99b9e4..83fc9e7fa5ad 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroDataConversion.java @@ -196,7 +196,7 @@ public static Object toIcebergAvroObject(Type type, org.apache.iceberg.types.Typ Type elementType = arrayType.getElementType(); org.apache.iceberg.types.Type elementIcebergType = icebergType.asListType().elementType(); - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); List list = new ArrayList<>(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { @@ -212,7 +212,7 @@ public static Object toIcebergAvroObject(Type type, org.apache.iceberg.types.Typ org.apache.iceberg.types.Type keyIcebergType = icebergType.asMapType().keyType(); org.apache.iceberg.types.Type valueIcebergType = icebergType.asMapType().valueType(); - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); Block rawValueBlock = sqlMap.getRawValueBlock(); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java index 1180e9c2b582..fdafcddb6b15 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/aggregation/IcebergThetaSketchForStats.java @@ -13,8 +13,8 @@ */ package io.trino.plugin.iceberg.aggregation; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; @@ -53,7 +53,7 @@ private IcebergThetaSketchForStats() {} @InputFunction @TypeParameter("T") - public static void input(@TypeParameter("T") Type type, @AggregationState DataSketchState state, @BlockPosition @SqlType("T") Block block, @BlockIndex int index) + public static void input(@TypeParameter("T") Type type, @AggregationState DataSketchState state, @BlockPosition @SqlType("T") ValueBlock block, @BlockIndex int index) { verify(!block.isNull(index), "Input function is not expected to be called on a NULL input"); diff --git a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java index 63ecbdde38ea..a8d023df4621 100644 --- a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java +++ b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/type/ModelType.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -45,7 +46,9 @@ protected ModelType(TypeSignature signature) @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java index f317a72bca24..9639d4546d57 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java @@ -202,7 +202,7 @@ private Object getObjectValue(Type type, Block block, int position) if (type instanceof ArrayType arrayType) { Type elementType = arrayType.getElementType(); - Block arrayBlock = block.getObject(position, Block.class); + Block arrayBlock = arrayType.getObject(block, position); List list = new ArrayList<>(arrayBlock.getPositionCount()); for (int i = 0; i < arrayBlock.getPositionCount(); i++) { @@ -216,7 +216,7 @@ private Object getObjectValue(Type type, Block block, int position) Type keyType = mapType.getKeyType(); Type valueType = mapType.getValueType(); - SqlMap sqlMap = block.getObject(position, SqlMap.class); + SqlMap sqlMap = mapType.getObject(block, position); int size = sqlMap.getSize(); int rawOffset = sqlMap.getRawOffset(); Block rawKeyBlock = sqlMap.getRawKeyBlock(); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java index 7124ec9ea47f..d9dcf28a849c 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ObjectIdType.java @@ -17,6 +17,7 @@ import io.airlift.slice.Slice; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.AbstractVariableWidthType; @@ -68,13 +69,15 @@ public Object getObjectValue(ConnectorSession session, Block block, int position } // TODO: There's no way to represent string value of a custom type - return new SqlVarbinary(block.getSlice(position, 0, block.getSliceLength(position)).getBytes()); + return new SqlVarbinary(getSlice(block, position).getBytes()); } @Override public Slice getSlice(Block block, int position) { - return block.getSlice(position, 0, block.getSliceLength(position)); + VariableWidthBlock valueBlock = (VariableWidthBlock) block.getUnderlyingValueBlock(); + int valuePosition = block.getUnderlyingValuePosition(position); + return valueBlock.getSlice(valuePosition); } @Override diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 82644e0414dc..c4a87c4df937 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -75,6 +75,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.MapBlock; import io.trino.spi.block.SqlMap; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -1306,8 +1307,8 @@ private ObjectReadFunction varcharMapReadFunction() varcharMapType.getValueType().writeSlice(valueBlockBuilder, utf8Slice(entry.getValue())); } } - return varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[] {0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build()) - .getObject(0, SqlMap.class); + MapBlock mapBlock = varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[]{0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build()); + return varcharMapType.getObject(mapBlock, 0); }); } diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java index cce1f2a47798..40d160635cf8 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java @@ -262,9 +262,9 @@ else if (type instanceof MapType mapType) { private static Object readObject(Type type, Block block, int position) { - if (type instanceof ArrayType) { - Type elementType = ((ArrayType) type).getElementType(); - return getArrayFromBlock(elementType, block.getObject(position, Block.class)); + if (type instanceof ArrayType arrayType) { + Type elementType = arrayType.getElementType(); + return getArrayFromBlock(elementType, arrayType.getObject(block, position)); } if (type instanceof MapType mapType) { return getMapFromSqlMap(type, mapType.getObject(block, position)); diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java index 62ca20d61bae..9f48902a4bf5 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java @@ -39,7 +39,7 @@ public BenchmarkAggregationFunction(ResolvedFunction resolvedFunction, Aggregati BoundSignature signature = resolvedFunction.getSignature(); intermediateType = getOnlyElement(aggregationImplementation.getAccumulatorStateDescriptors()).getSerializer().getSerializedType(); finalType = signature.getReturnType(); - accumulatorFactory = generateAccumulatorFactory(signature, aggregationImplementation, resolvedFunction.getFunctionNullability()); + accumulatorFactory = generateAccumulatorFactory(signature, aggregationImplementation, resolvedFunction.getFunctionNullability(), true); } public AggregatorFactory bind(List inputChannels) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 26be645eeba9..5cb4c66e35d3 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -71,7 +71,6 @@ import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.tests.QueryTemplate.parameter; import static io.trino.tests.QueryTemplate.queryTemplate; -import static io.trino.type.UnknownType.UNKNOWN; import static java.lang.String.format; import static java.util.Collections.nCopies; import static java.util.stream.Collectors.joining; @@ -1382,7 +1381,7 @@ public void testDescribeInputNoParameters() .addPreparedStatement("my_query", "SELECT * FROM nation") .build(); assertThat(query(session, "DESCRIBE INPUT my_query")) - .hasOutputTypes(List.of(UNKNOWN, UNKNOWN)) + .hasOutputTypes(List.of(BIGINT, VARCHAR)) .returnsEmptyResult(); }