Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Fixed
- Fix bytes parameter on `_cat/recovery` ([#17598](https://github.com/opensearch-project/OpenSearch/pull/17598))
- Fix slow performance of FeatureFlag checks ([#17611](https://github.com/opensearch-project/OpenSearch/pull/17611))
- Fix shard recovery in pull-based ingestion to avoid skipping messages ([#17868](https://github.com/opensearch-project/OpenSearch/pull/17868))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
private ExecutorService processorThread;

// start of the batch, inclusive
private IngestionShardPointer batchStartPointer;
private IngestionShardPointer initialBatchStartPointer;
private boolean includeBatchStartPointer = false;

private ResetState resetState;
Expand Down Expand Up @@ -105,7 +105,7 @@
this.consumer = Objects.requireNonNull(consumer);
this.resetState = resetState;
this.resetValue = resetValue;
this.batchStartPointer = startPointer;
this.initialBatchStartPointer = startPointer;
this.state = initialState;
this.persistedPointers = persistedPointers;
if (!this.persistedPointers.isEmpty()) {
Expand Down Expand Up @@ -170,23 +170,23 @@
if (resetState != ResetState.NONE) {
switch (resetState) {
case EARLIEST:
batchStartPointer = consumer.earliestPointer();
logger.info("Resetting offset by seeking to earliest offset {}", batchStartPointer.asString());
initialBatchStartPointer = consumer.earliestPointer();
logger.info("Resetting offset by seeking to earliest offset {}", initialBatchStartPointer.asString());
break;
case LATEST:
batchStartPointer = consumer.latestPointer();
logger.info("Resetting offset by seeking to latest offset {}", batchStartPointer.asString());
initialBatchStartPointer = consumer.latestPointer();
logger.info("Resetting offset by seeking to latest offset {}", initialBatchStartPointer.asString());
break;
case REWIND_BY_OFFSET:
batchStartPointer = consumer.pointerFromOffset(resetValue);
logger.info("Resetting offset by seeking to offset {}", batchStartPointer.asString());
initialBatchStartPointer = consumer.pointerFromOffset(resetValue);
logger.info("Resetting offset by seeking to offset {}", initialBatchStartPointer.asString());
break;
case REWIND_BY_TIMESTAMP:
batchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));
initialBatchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));

Check warning on line 185 in server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java#L185

Added line #L185 was not covered by tests
logger.info(
"Resetting offset by seeking to timestamp {}, corresponding offset {}",
resetValue,
batchStartPointer.asString()
initialBatchStartPointer.asString()

Check warning on line 189 in server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java#L189

Added line #L189 was not covered by tests
);
break;
}
Expand All @@ -209,7 +209,8 @@
List<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> results;

if (includeBatchStartPointer) {
results = consumer.readNext(batchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT);
results = consumer.readNext(initialBatchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT);
includeBatchStartPointer = false;
} else {
results = consumer.readNext(MAX_POLL_SIZE, POLL_TIMEOUT);
}
Expand All @@ -220,38 +221,47 @@
}

state = State.PROCESSING;
// process the records
boolean firstInBatch = true;
for (IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message> result : results) {
if (firstInBatch) {
// update the batch start pointer to the next batch
batchStartPointer = result.getPointer();
firstInBatch = false;
}
processRecords(results);
} catch (Exception e) {

Check warning on line 225 in server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java#L225

Added line #L225 was not covered by tests
// Pause ingestion when an error is encountered while polling the streaming source.
// Currently we do not have a good way to skip past the failing messages.
// The user will have the option to manually update the offset and resume ingestion.
// todo: support retry?
logger.error("Pausing ingestion. Fatal error occurred in polling the shard {}: {}", consumer.getShardId(), e);
pause();

Check warning on line 231 in server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java#L230-L231

Added lines #L230 - L231 were not covered by tests
}
}
}

// check if the message is already processed
if (isProcessed(result.getPointer())) {
logger.info("Skipping message with pointer {} as it is already processed", result.getPointer().asString());
continue;
}
totalPolledCount.inc();
blockingQueue.put(result);

logger.debug(
"Put message {} with pointer {} to the blocking queue",
String.valueOf(result.getMessage().getPayload()),
result.getPointer().asString()
);
private void processRecords(List<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> results) {
for (IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message> result : results) {
try {
// check if the message is already processed
if (isProcessed(result.getPointer())) {
logger.debug("Skipping message with pointer {} as it is already processed", () -> result.getPointer().asString());
continue;
}
// for future reads, we do not need to include the batch start pointer, and read from the last successful pointer.
includeBatchStartPointer = false;
} catch (Throwable e) {
logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e);
totalPolledCount.inc();
blockingQueue.put(result);

logger.debug(
"Put message {} with pointer {} to the blocking queue",
String.valueOf(result.getMessage().getPayload()),
result.getPointer().asString()
);
} catch (Exception e) {
logger.error(
"Error in processing a record. Shard {}, pointer {}: {}",
consumer.getShardId(),
result.getPointer().asString(),
e
);
errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.POLLING);

if (!errorStrategy.shouldIgnoreError(e, IngestionErrorStrategy.ErrorStage.POLLING)) {
// Blocking error encountered. Pause poller to stop processing remaining updates.
pause();
break;
}
}
}
Expand Down Expand Up @@ -329,9 +339,16 @@
return closed;
}

/**
* Returns the batch start pointer from where the poller can resume in case of shard recovery. The poller and
* processor are decoupled in this implementation, and hence the latest pointer tracked by the processor acts as the
* recovery/start point. In case the processor has not started tracking, then the initial batchStartPointer used by
* the poller acts as the start point.
*/
@Override
public IngestionShardPointer getBatchStartPointer() {
return batchStartPointer;
IngestionShardPointer currentShardPointer = processorRunnable.getCurrentShardPointer();
return currentShardPointer == null ? initialBatchStartPointer : currentShardPointer;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.Term;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.common.metrics.CounterMetric;
import org.opensearch.common.util.RequestUtils;
Expand Down Expand Up @@ -59,6 +60,10 @@ public class MessageProcessorRunnable implements Runnable {
private final MessageProcessor messageProcessor;
private final CounterMetric stats = new CounterMetric();

// tracks the most recent pointer that is being processed
@Nullable
private volatile IngestionShardPointer currentShardPointer;

/**
* Constructor.
*
Expand Down Expand Up @@ -274,6 +279,7 @@ public void run() {
if (readResult != null) {
try {
stats.inc();
currentShardPointer = readResult.getPointer();
messageProcessor.process(readResult.getMessage(), readResult.getPointer());
readResult = null;
} catch (Exception e) {
Expand Down Expand Up @@ -308,4 +314,9 @@ public IngestionErrorStrategy getErrorStrategy() {
public void setErrorStrategy(IngestionErrorStrategy errorStrategy) {
this.errorStrategy = errorStrategy;
}

/**
 * Returns the pointer of the message most recently handed to the message processor.
 *
 * @return the most recent pointer being processed, or {@code null} when no pointer has been tracked yet
 */
@Nullable
public IngestionShardPointer getCurrentShardPointer() {
    return this.currentShardPointer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public void testCreateEngine() throws IOException {
// verify the commit data
Assert.assertEquals(7, commitData.size());
// the commit data is the start of the current batch
Assert.assertEquals("0", commitData.get(StreamPoller.BATCH_START));
Assert.assertEquals("1", commitData.get(StreamPoller.BATCH_START));

// verify the stored offsets
var offset = new FakeIngestionSource.FakeIngestionShardPointer(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,14 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte
);
IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class);
when(mockConsumer.getShardId()).thenReturn(0);
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed"))
.thenReturn(readResultsBatch1)
.thenThrow(new RuntimeException("message3 poll failed"))
.thenReturn(readResultsBatch2)
.thenReturn(Collections.emptyList());
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1);
when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList());

IngestionErrorStrategy errorStrategy = spy(new DropIngestionErrorStrategy("ingestion_source"));
ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class);
doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any());
processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy);

poller = new DefaultStreamPoller(
new FakeIngestionSource.FakeIngestionShardPointer(0),
persistedPointers,
Expand All @@ -288,7 +289,7 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte
Thread.sleep(sleepTime);

verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING));
verify(processor, times(2)).process(any(), any());
verify(mockQueue, times(4)).put(any());
}

public void testBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException {
Expand All @@ -314,12 +315,14 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt
);
IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class);
when(mockConsumer.getShardId()).thenReturn(0);
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed"))
.thenReturn(readResultsBatch1)
.thenReturn(readResultsBatch2)
.thenReturn(Collections.emptyList());
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1);
when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList());

IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source"));
ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class);
doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any());
processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy);

poller = new DefaultStreamPoller(
new FakeIngestionSource.FakeIngestionShardPointer(0),
persistedPointers,
Expand All @@ -334,7 +337,6 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt
Thread.sleep(sleepTime);

verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING));
verify(processor, never()).process(any(), any());
assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState());
assertTrue(poller.isPaused());
}
Expand Down Expand Up @@ -374,4 +376,54 @@ public void testUpdateErrorStrategy() {
assertTrue(poller.getErrorStrategy() instanceof BlockIngestionErrorStrategy);
assertTrue(processorRunnable.getErrorStrategy() instanceof BlockIngestionErrorStrategy);
}

/**
 * Verifies the batch start pointer reported by the poller when the processor keeps failing.
 * The processor is stubbed to throw on every message under the block error strategy, and the
 * poller is expected to still report pointer 0 as its batch start pointer after it has run.
 */
public void testPersistedBatchStartPointer() throws TimeoutException, InterruptedException {
    // two extra messages on top of the ones already present in the fixture
    messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8));
    messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8));

    // pre-read the two batches that the mocked consumer will hand out
    List<
        IngestionShardConsumer.ReadResult<
            FakeIngestionSource.FakeIngestionShardPointer,
            FakeIngestionSource.FakeIngestionMessage>> firstBatch = fakeConsumer.readNext(fakeConsumer.earliestPointer(), true, 2, 100);
    FakeIngestionSource.FakeIngestionShardPointer secondBatchStart = new FakeIngestionSource.FakeIngestionShardPointer(2);
    List<
        IngestionShardConsumer.ReadResult<
            FakeIngestionSource.FakeIngestionShardPointer,
            FakeIngestionSource.FakeIngestionMessage>> secondBatch = fakeConsumer.readNext(secondBatchStart, true, 2, 100);

    // consumer stub: the pointer-based read returns the first batch, plain reads return the second batch and then nothing
    IngestionShardConsumer stubbedConsumer = mock(IngestionShardConsumer.class);
    when(stubbedConsumer.getShardId()).thenReturn(0);
    when(stubbedConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(firstBatch);
    when(stubbedConsumer.readNext(anyLong(), anyInt())).thenReturn(secondBatch).thenReturn(Collections.emptyList());

    // Four messages are published in total while the queue only holds three, so the poller ends up
    // blocked on the 4th message while the failing processor is validated.
    IngestionErrorStrategy blockingStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source"));
    doThrow(new RuntimeException()).when(processor).process(any(), any());
    processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(3), processor, blockingStrategy);

    poller = new DefaultStreamPoller(
        new FakeIngestionSource.FakeIngestionShardPointer(0),
        persistedPointers,
        stubbedConsumer,
        processorRunnable,
        StreamPoller.ResetState.NONE,
        "",
        blockingStrategy,
        StreamPoller.State.NONE
    );
    poller.start();
    Thread.sleep(sleepTime);

    FakeIngestionSource.FakeIngestionShardPointer expectedStart = new FakeIngestionSource.FakeIngestionShardPointer(0);
    assertEquals(expectedStart, poller.getBatchStartPointer());
}
}
Loading