diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java
index 0e814a88b6d8..5c853d9011ac 100644
--- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java
+++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaFilterManager.java
@@ -47,7 +47,6 @@
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
-import static com.google.common.collect.Iterables.getOnlyElement;
 import static io.trino.plugin.kafka.KafkaErrorCode.KAFKA_SPLIT_ERROR;
 import static io.trino.plugin.kafka.KafkaInternalFieldManager.InternalFieldId.OFFSET_TIMESTAMP_FIELD;
 import static io.trino.plugin.kafka.KafkaInternalFieldManager.InternalFieldId.PARTITION_ID_FIELD;
@@ -57,6 +56,7 @@
 import static java.lang.Math.floorDiv;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.toMap;
 
 public class KafkaFilterManager
 {
@@ -123,13 +123,19 @@ public KafkaFilteringResult getKafkaFilterResult(
             try (KafkaConsumer<byte[], byte[]> kafkaConsumer = consumerFactory.create(session)) {
                 // filter negative value to avoid java.lang.IllegalArgumentException when using KafkaConsumer offsetsForTimes
                 if (offsetTimestampRanged.get().begin() > INVALID_KAFKA_RANGE_INDEX) {
-                    partitionBeginOffsets = overridePartitionBeginOffsets(partitionBeginOffsets,
-                            partition -> findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partition, offsetTimestampRanged.get().begin()));
+                    long partitionBeginTimestamp = floorDiv(offsetTimestampRanged.get().begin(), MICROSECONDS_PER_MILLISECOND);
+                    Map<TopicPartition, Long> partitionBeginTimestamps = partitionBeginOffsets.entrySet().stream()
+                            .collect(toMap(Map.Entry::getKey, _ -> partitionBeginTimestamp));
+                    Map<TopicPartition, Optional<Long>> beginOffsets = findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partitionBeginTimestamps);
+                    partitionBeginOffsets = overridePartitionBeginOffsets(partitionBeginOffsets, beginOffsets::get);
                 }
                 if (isTimestampUpperBoundPushdownEnabled(session, kafkaTableHandle.topicName())) {
                     if (offsetTimestampRanged.get().end() > INVALID_KAFKA_RANGE_INDEX) {
-                        partitionEndOffsets = overridePartitionEndOffsets(partitionEndOffsets,
-                                partition -> findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partition, offsetTimestampRanged.get().end()));
+                        long partitionEndTimestamp = floorDiv(offsetTimestampRanged.get().end(), MICROSECONDS_PER_MILLISECOND);
+                        Map<TopicPartition, Long> partitionEndTimestamps = partitionEndOffsets.entrySet().stream()
+                                .collect(toMap(Map.Entry::getKey, _ -> partitionEndTimestamp));
+                        Map<TopicPartition, Optional<Long>> endOffsets = findOffsetsForTimestampGreaterOrEqual(kafkaConsumer, partitionEndTimestamps);
+                        partitionEndOffsets = overridePartitionEndOffsets(partitionEndOffsets, endOffsets::get);
                     }
                 }
             }
@@ -172,11 +178,12 @@ private boolean isTimestampUpperBoundPushdownEnabled(ConnectorSession session, S
         return KafkaSessionProperties.isTimestampUpperBoundPushdownEnabled(session);
     }
 
-    private static Optional<Long> findOffsetsForTimestampGreaterOrEqual(KafkaConsumer<byte[], byte[]> kafkaConsumer, TopicPartition topicPartition, long timestamp)
+    private static Map<TopicPartition, Optional<Long>> findOffsetsForTimestampGreaterOrEqual(KafkaConsumer<byte[], byte[]> kafkaConsumer, Map<TopicPartition, Long> timestamps)
     {
-        final long transferTimestamp = floorDiv(timestamp, MICROSECONDS_PER_MILLISECOND);
-        Map<TopicPartition, OffsetAndTimestamp> topicPartitionOffsets = kafkaConsumer.offsetsForTimes(ImmutableMap.of(topicPartition, transferTimestamp));
-        return Optional.ofNullable(getOnlyElement(topicPartitionOffsets.values(), null)).map(OffsetAndTimestamp::offset);
+        Map<TopicPartition, OffsetAndTimestamp> topicPartitionOffsetAndTimestamps = kafkaConsumer.offsetsForTimes(timestamps);
+        return topicPartitionOffsetAndTimestamps.entrySet().stream()
+                .collect(toMap(Map.Entry::getKey, entry -> Optional.ofNullable(entry.getValue())
+                        .map(OffsetAndTimestamp::offset)));
     }
 