[KafkaIO] Determine partition backlog using endOffsets instead of seek2End and position (#32889)

* Determine partition backlog using endOffsets instead of seekToEnd and position

* Remove offset consumer assignments

* Explicitly update partitions and start/end offsets for relevant mock consumers

* Clean up partition and offset updates in tests
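
For context, the sketch below contrasts the two ways of reading the latest (end) offsets with the stock Kafka Consumer API. It is a standalone illustration, not the Beam code itself: seekToEnd()/position() operate on the consumer's assignment and mutate its fetch position one partition at a time, while endOffsets() is a single batched lookup that requires no assignment or seek. The class and method names (BacklogOffsetsSketch, latestOffsetsViaSeek, latestOffsetsViaEndOffsets) are made up for illustration.

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.TopicPartition;

public class BacklogOffsetsSketch {
  // Old pattern: requires the partitions to be assigned; one seek + position call per partition.
  static Map<TopicPartition, Long> latestOffsetsViaSeek(
      Consumer<byte[], byte[]> consumer, Collection<TopicPartition> partitions) {
    Map<TopicPartition, Long> latest = new HashMap<>();
    for (TopicPartition tp : partitions) {
      consumer.seekToEnd(Collections.singleton(tp));
      latest.put(tp, consumer.position(tp));
    }
    return latest;
  }

  // New pattern: one batched lookup; no assignment and no change to fetch positions.
  static Map<TopicPartition, Long> latestOffsetsViaEndOffsets(
      Consumer<byte[], byte[]> consumer, Collection<TopicPartition> partitions) {
    return consumer.endOffsets(partitions);
  }
}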
sjvanrossum authored Nov 6, 2024
1 parent 36c19a3 commit deeddd1
Showing 6 changed files with 66 additions and 73 deletions.
@@ -36,7 +36,6 @@
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.MockConsumer;
@@ -138,10 +137,6 @@ public synchronized void assign(final Collection<TopicPartition> assigned) {
.collect(Collectors.toList());
super.assign(realPartitions);
assignedPartitions.set(ImmutableList.copyOf(realPartitions));
for (TopicPartition tp : realPartitions) {
updateBeginningOffsets(ImmutableMap.of(tp, 0L));
updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size()));
}
}
// Override offsetsForTimes() in order to look up the offsets by timestamp.
@Override
@@ -163,9 +158,12 @@ public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
}
};

for (String topic : getTopics()) {
consumer.updatePartitions(topic, partitionInfoMap.get(topic));
}
partitionInfoMap.forEach(consumer::updatePartitions);
consumer.updateBeginningOffsets(
kafkaRecords.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L)));
consumer.updateEndOffsets(
kafkaRecords.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())));

Runnable recordEnqueueTask =
new Runnable() {
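
Because the mock consumers no longer seed offsets inside assign(), the tests now register partition metadata and beginning/end offsets up front so that endOffsets() returns meaningful values. A minimal, standalone MockConsumer setup along those lines (the topic name, partition number, and offset values here are placeholders):

import java.util.Collections;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;

public class MockConsumerEndOffsetsSketch {
  public static void main(String[] args) {
    TopicPartition tp = new TopicPartition("topic_a", 0);
    MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);

    // Register partition metadata so partitionsFor() and assignment logic can see the partition.
    consumer.updatePartitions(
        "topic_a",
        Collections.singletonList(
            new PartitionInfo("topic_a", 0, new Node(0, "localhost", 9092), null, null)));

    // Seed beginning and end offsets explicitly; MockConsumer.endOffsets() serves these values.
    consumer.updateBeginningOffsets(Collections.singletonMap(tp, 0L));
    consumer.updateEndOffsets(Collections.singletonMap(tp, 42L));

    // The backlog computation can now call endOffsets() without assigning or seeking.
    System.out.println(consumer.endOffsets(Collections.singleton(tp))); // {topic_a-0=42}
  }
}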
@@ -138,7 +138,6 @@ public boolean start() throws IOException {
name, spec.getOffsetConsumerConfig(), spec.getConsumerConfig());

offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig);
ConsumerSpEL.evaluateAssign(offsetConsumer, topicPartitions);

// Fetch offsets once before running periodically.
updateLatestOffsets();
@@ -711,23 +710,28 @@ private void setupInitialOffset(PartitionState<K, V> pState) {
// Called from setupInitialOffset() at the start and then periodically from offsetFetcher thread.
private void updateLatestOffsets() {
Consumer<byte[], byte[]> offsetConsumer = Preconditions.checkStateNotNull(this.offsetConsumer);
for (PartitionState<K, V> p : partitionStates) {
try {
Instant fetchTime = Instant.now();
ConsumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition);
long offset = offsetConsumer.position(p.topicPartition);
p.setLatestOffset(offset, fetchTime);
} catch (Exception e) {
if (closed.get()) { // Ignore the exception if the reader is closed.
break;
}
List<TopicPartition> topicPartitions =
Preconditions.checkStateNotNull(source.getSpec().getTopicPartitions());
Instant fetchTime = Instant.now();
try {
Map<TopicPartition, Long> endOffsets = offsetConsumer.endOffsets(topicPartitions);
for (PartitionState<K, V> p : partitionStates) {
p.setLatestOffset(
Preconditions.checkStateNotNull(
endOffsets.get(p.topicPartition),
"No end offset found for partition %s.",
p.topicPartition),
fetchTime);
}
} catch (Exception e) {
if (!closed.get()) { // Ignore the exception if the reader is closed.
LOG.warn(
"{}: exception while fetching latest offset for partition {}. will be retried.",
"{}: exception while fetching latest offset for partitions {}. will be retried.",
this,
p.topicPartition,
topicPartitions,
e);
// Don't update the latest offset.
}
// Don't update the latest offset.
}

LOG.debug("{}: backlog {}", this, getSplitBacklogBytes());
@@ -19,6 +19,7 @@

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -253,13 +254,16 @@ private static class KafkaLatestOffsetEstimator
Consumer<byte[], byte[]> offsetConsumer, TopicPartition topicPartition) {
this.offsetConsumer = offsetConsumer;
this.topicPartition = topicPartition;
ConsumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition));
memoizedBacklog =
Suppliers.memoizeWithExpiration(
() -> {
synchronized (offsetConsumer) {
ConsumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition);
return offsetConsumer.position(topicPartition);
return Preconditions.checkStateNotNull(
offsetConsumer
.endOffsets(Collections.singleton(topicPartition))
.get(topicPartition),
"No end offset found for partition %s.",
topicPartition);
}
},
1,
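
The estimator above keeps its previous shape: a supplier memoized for a short interval so repeated backlog queries do not hammer the broker, now backed by endOffsets() instead of seekToEnd()/position(). A standalone sketch of that pattern using plain Guava (the Beam code uses its vendored Guava and Beam's Preconditions; the class and method names below are illustrative):

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.TopicPartition;

class CachedEndOffsetSketch {
  private final Supplier<Long> memoizedEndOffset;

  CachedEndOffsetSketch(Consumer<byte[], byte[]> offsetConsumer, TopicPartition topicPartition) {
    this.memoizedEndOffset =
        Suppliers.memoizeWithExpiration(
            () -> {
              // Serialize access: KafkaConsumer is not safe for concurrent use.
              synchronized (offsetConsumer) {
                Long end =
                    offsetConsumer
                        .endOffsets(Collections.singleton(topicPartition))
                        .get(topicPartition);
                return end != null ? end : 0L; // fall back to 0 if the broker returns nothing
              }
            },
            1,
            TimeUnit.SECONDS);
  }

  long estimateEndOffset() {
    return memoizedEndOffset.get();
  }
}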
@@ -77,7 +77,7 @@
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.io.kafka.KafkaIO.Read.FakeFlinkPipelineOptions;
import org.apache.beam.sdk.io.kafka.KafkaMocks.PositionErrorConsumerFactory;
import org.apache.beam.sdk.io.kafka.KafkaMocks.EndOffsetErrorConsumerFactory;
import org.apache.beam.sdk.io.kafka.KafkaMocks.SendErrorProducerFactory;
import org.apache.beam.sdk.metrics.DistributionResult;
import org.apache.beam.sdk.metrics.Lineage;
@@ -267,10 +267,6 @@ private static MockConsumer<byte[], byte[]> mkMockConsumer(
public synchronized void assign(final Collection<TopicPartition> assigned) {
super.assign(assigned);
assignedPartitions.set(ImmutableList.copyOf(assigned));
for (TopicPartition tp : assigned) {
updateBeginningOffsets(ImmutableMap.of(tp, 0L));
updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size()));
}
}
// Override offsetsForTimes() in order to look up the offsets by timestamp.
@Override
@@ -290,9 +286,12 @@ public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
}
};

for (String topic : topics) {
consumer.updatePartitions(topic, partitionMap.get(topic));
}
partitionMap.forEach(consumer::updatePartitions);
consumer.updateBeginningOffsets(
records.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L)));
consumer.updateEndOffsets(
records.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())));

// MockConsumer does not maintain any relationship between partition seek position and the
// records added. e.g. if we add 10 records to a partition and then seek to end of the
@@ -1525,13 +1524,14 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception {

List<String> topics = ImmutableList.of("topic_a");

PositionErrorConsumerFactory positionErrorConsumerFactory = new PositionErrorConsumerFactory();
EndOffsetErrorConsumerFactory endOffsetErrorConsumerFactory =
new EndOffsetErrorConsumerFactory();

UnboundedSource<KafkaRecord<Integer, Long>, KafkaCheckpointMark> source =
KafkaIO.<Integer, Long>read()
.withBootstrapServers("myServer1:9092,myServer2:9092")
.withTopics(topics)
.withConsumerFactoryFn(positionErrorConsumerFactory)
.withConsumerFactoryFn(endOffsetErrorConsumerFactory)
.withKeyDeserializer(IntegerDeserializer.class)
.withValueDeserializer(LongDeserializer.class)
.makeSource();
@@ -1540,7 +1540,7 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception {

reader.start();

unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition");
unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partitions");

reader.close();
}
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.kafka;

import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -27,8 +28,8 @@
import org.apache.beam.sdk.values.KV;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.MockProducer;
import org.apache.kafka.clients.producer.Producer;
@@ -66,51 +67,33 @@ public Producer<Integer, Long> apply(Map<String, Object> input) {
}
}

public static final class PositionErrorConsumer extends MockConsumer<byte[], byte[]> {

public PositionErrorConsumer() {
super(null);
}

@Override
public synchronized long position(TopicPartition partition) {
throw new KafkaException("fakeException");
}

@Override
public synchronized List<PartitionInfo> partitionsFor(String topic) {
return Collections.singletonList(
new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null));
}
}

public static final class PositionErrorConsumerFactory
public static final class EndOffsetErrorConsumerFactory
implements SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> {
public PositionErrorConsumerFactory() {}
public EndOffsetErrorConsumerFactory() {}

@Override
public MockConsumer<byte[], byte[]> apply(Map<String, Object> input) {
final MockConsumer<byte[], byte[]> consumer;
if (input.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
return new PositionErrorConsumer();
} else {
MockConsumer<byte[], byte[]> consumer =
new MockConsumer<byte[], byte[]>(null) {
consumer =
new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
@Override
public synchronized long position(TopicPartition partition) {
return 1L;
}

@Override
public synchronized ConsumerRecords<byte[], byte[]> poll(long timeout) {
return ConsumerRecords.empty();
public synchronized Map<TopicPartition, Long> endOffsets(
Collection<TopicPartition> partitions) {
throw new KafkaException("fakeException");
}
};
consumer.updatePartitions(
"topic_a",
Collections.singletonList(
new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null)));
return consumer;
} else {
consumer = new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST);
}
consumer.updatePartitions(
"topic_a",
Collections.singletonList(
new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null)));
consumer.updateBeginningOffsets(
Collections.singletonMap(new TopicPartition("topic_a", 1), 0L));
consumer.updateEndOffsets(Collections.singletonMap(new TopicPartition("topic_a", 1), 0L));
return consumer;
}
}

@@ -205,6 +205,8 @@ public SimpleMockKafkaConsumer(
OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) {
super(offsetResetStrategy);
this.topicPartition = topicPartition;
updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L));
updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE));
}

public void reset() {
@@ -214,6 +216,8 @@ public void reset() {
this.startOffsetForTime = KV.of(0L, Instant.now());
this.stopOffsetForTime = KV.of(Long.MAX_VALUE, null);
this.numOfRecordsPerPoll = 0L;
updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L));
updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE));
}

public void setRemoved() {
