Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -40,7 +40,7 @@
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.metadata.Signature.comparableTypeParameter;
import static io.trino.metadata.Signature.typeVariable;
import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.spi.type.TypeSignature.rowType;
Expand All @@ -57,7 +57,7 @@ public class MultimapAggregationFunction
MultimapAggregationFunction.class,
"output",
Type.class,
BlockPositionEqual.class,
BlockPositionIsDistinctFrom.class,
BlockPositionHashCode.class,
Type.class,
MultimapAggregationState.class,
Expand Down Expand Up @@ -103,7 +103,7 @@ public MultimapAggregationFunction(BlockTypeOperators blockTypeOperators)
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type keyType = boundSignature.getArgumentType(0);
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionIsDistinctFrom keyDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);

Type valueType = boundSignature.getArgumentType(1);
Expand All @@ -114,7 +114,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature)
INPUT_FUNCTION,
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType),
MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyDistinctOperator, keyHashCode, valueType),
ImmutableList.of(new AccumulatorStateDescriptor<>(
MultimapAggregationState.class,
stateSerializer,
Expand All @@ -131,7 +131,7 @@ public static void combine(MultimapAggregationState state, MultimapAggregationSt
state.merge(otherState);
}

public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out)
public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out)
{
if (state.isEmpty()) {
out.appendNull();
Expand All @@ -141,7 +141,7 @@ public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositi
ObjectBigArray<BlockBuilder> valueArrayBlockBuilders = new ObjectBigArray<>();
valueArrayBlockBuilders.ensureCapacity(state.getEntryCount());
BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100));
TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, state.getEntryCount(), NAME);
TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctOperator, keyHashCode, state.getEntryCount(), NAME);

state.forEach((key, value, keyValueIndex) -> {
// Merge values of the same key into an array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM;

@ScalarFunction("array_except")
@Description("Returns an array of elements that are in the first array but not the second, without duplicates.")
Expand All @@ -44,9 +43,9 @@ private ArrayExceptFunction() {}
public static Block except(
@TypeParameter("E") Type type,
@OperatorDependency(
operator = EQUAL,
operator = IS_DISTINCT_FROM,
argumentTypes = {"E", "E"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual elementEqual,
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom isDistinctOperator,
@OperatorDependency(
operator = HASH_CODE,
argumentTypes = "E",
Expand All @@ -61,7 +60,7 @@ public static Block except(
return leftArray;
}

TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftPositionCount, "array_except");
TypedSet typedSet = createDistinctTypedSet(type, isDistinctOperator, elementHashCode, leftPositionCount, "array_except");
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount);
for (int i = 0; i < rightPositionCount; i++) {
typedSet.add(rightArray, i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;

import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM;
import static io.trino.spi.type.BigintType.BIGINT;

@ScalarFunction("array_union")
Expand All @@ -49,9 +48,9 @@ private ArrayUnionFunction() {}
public static Block union(
@TypeParameter("E") Type type,
@OperatorDependency(
operator = EQUAL,
operator = IS_DISTINCT_FROM,
Comment thread
jirassimok marked this conversation as resolved.
Outdated
argumentTypes = {"E", "E"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual elementEqual,
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom isDistinctOperator,
@OperatorDependency(
operator = HASH_CODE,
argumentTypes = "E",
Expand All @@ -62,9 +61,9 @@ public static Block union(
int leftArrayCount = leftArray.getPositionCount();
int rightArrayCount = rightArray.getPositionCount();
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftArrayCount + rightArrayCount);
TypedSet typedSet = createEqualityTypedSet(
TypedSet typedSet = createDistinctTypedSet(
type,
elementEqual,
isDistinctOperator,
elementHashCode,
distinctElementBlockBuilder,
leftArrayCount + rightArrayCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
import io.trino.spi.type.TypeSignature;
import io.trino.sql.gen.VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.Optional;

import static io.trino.metadata.FunctionKind.SCALAR;
import static io.trino.metadata.Signature.typeVariable;
import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
Expand All @@ -61,7 +61,7 @@ public final class MapConcatFunction
MapConcatFunction.class,
"mapConcat",
MapType.class,
BlockPositionEqual.class,
BlockPositionIsDistinctFrom.class,
BlockPositionHashCode.class,
Object.class,
Block[].class);
Expand Down Expand Up @@ -95,14 +95,14 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature)

MapType mapType = (MapType) boundSignature.getReturnType();
Type keyType = mapType.getKeyType();
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionIsDistinctFrom keysDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);

MethodHandleAndConstructor methodHandleAndConstructor = generateVarArgsToArrayAdapter(
Block.class,
Block.class,
boundSignature.getArity(),
MethodHandles.insertArguments(METHOD_HANDLE, 0, mapType, keyEqual, keyHashCode),
MethodHandles.insertArguments(METHOD_HANDLE, 0, mapType, keysDistinctOperator, keyHashCode),
USER_STATE_FACTORY.bindTo(mapType));

return new ChoicesScalarFunctionImplementation(
Expand All @@ -120,7 +120,7 @@ public static Object createMapState(MapType mapType)
}

@UsedByGeneratedCode
public static Block mapConcat(MapType mapType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Object state, Block[] maps)
public static Block mapConcat(MapType mapType, BlockPositionIsDistinctFrom keysDistinctOperator, BlockPositionHashCode keyHashCode, Object state, Block[] maps)
{
int entries = 0;
int lastMapIndex = maps.length - 1;
Expand All @@ -144,7 +144,7 @@ public static Block mapConcat(MapType mapType, BlockPositionEqual keyEqual, Bloc
// TODO: we should move TypedSet into user state as well
Type keyType = mapType.getKeyType();
Type valueType = mapType.getValueType();
TypedSet typedSet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entries / 2, FUNCTION_NAME);
TypedSet typedSet = createDistinctTypedSet(keyType, keysDistinctOperator, keyHashCode, entries / 2, FUNCTION_NAME);
BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0);
BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,15 @@
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM;
import static java.lang.String.format;

@ScalarFunction("map_from_entries")
Expand All @@ -61,9 +60,9 @@ public MapFromEntriesFunction(@TypeParameter("map(K,V)") Type mapType)
@SqlNullable
public Block mapFromEntries(
@OperatorDependency(
operator = EQUAL,
operator = IS_DISTINCT_FROM,
argumentTypes = {"K", "K"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual keyEqual,
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom keysDistinctOperator,
@OperatorDependency(
operator = HASH_CODE,
argumentTypes = "K",
Expand All @@ -84,7 +83,7 @@ public Block mapFromEntries(

BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0);
BlockBuilder resultBuilder = mapBlockBuilder.beginBlockEntry();
TypedSet uniqueKeys = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, "map_from_entries");
TypedSet uniqueKeys = createDistinctTypedSet(keyType, keysDistinctOperator, keyHashCode, entryCount, "map_from_entries");

for (int i = 0; i < entryCount; i++) {
if (mapEntries.isNull(i)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -41,7 +41,7 @@
import static com.google.common.base.Preconditions.checkState;
import static io.trino.metadata.Signature.castableToTypeParameter;
import static io.trino.metadata.Signature.typeVariable;
import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.block.MethodHandleUtil.compose;
import static io.trino.spi.block.MethodHandleUtil.nativeValueWriter;
Expand All @@ -66,7 +66,7 @@ public final class MapToMapCast
MethodHandle.class,
MethodHandle.class,
Type.class,
BlockPositionEqual.class,
BlockPositionIsDistinctFrom.class,
BlockPositionHashCode.class,
ConnectorSession.class,
Block.class);
Expand Down Expand Up @@ -116,7 +116,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu

MethodHandle keyProcessor = buildProcessor(functionDependencies, fromKeyType, toKeyType, true);
MethodHandle valueProcessor = buildProcessor(functionDependencies, fromValueType, toValueType, false);
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(toKeyType);
BlockPositionIsDistinctFrom keyEqual = blockTypeOperators.getDistinctFromOperator(toKeyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(toKeyType);
MethodHandle target = MethodHandles.insertArguments(METHOD_HANDLE, 0, keyProcessor, valueProcessor, toMapType, keyEqual, keyHashCode);
return new ChoicesScalarFunctionImplementation(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), target);
Expand Down Expand Up @@ -238,14 +238,14 @@ public static Block mapCast(
MethodHandle keyProcessFunction,
MethodHandle valueProcessFunction,
Type targetType,
BlockPositionEqual keyEqual,
BlockPositionIsDistinctFrom keyDistinctOperator,
BlockPositionHashCode keyHashCode,
ConnectorSession session,
Block fromMap)
{
checkState(targetType.getTypeParameters().size() == 2, "Expect two type parameters for targetType");
Type toKeyType = targetType.getTypeParameters().get(0);
TypedSet resultKeys = createEqualityTypedSet(toKeyType, keyEqual, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast");
TypedSet resultKeys = createDistinctTypedSet(toKeyType, keyDistinctOperator, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can somewhat test it with:

select cast(map(array['NaN', 'NaN '], array[1, 2]) as map<double, bigint>);

It will fail now. It used to not.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test.


// Cast the keys into a new block
BlockBuilder keyBlockBuilder = toKeyType.createBlockBuilder(null, fromMap.getPositionCount() / 2);
Expand Down
Loading