diff --git a/CHANGELOG.md b/CHANGELOG.md index d86de830248e7..e6f9e29700b6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Fixed - Fix bytes parameter on `_cat/recovery` ([#17598](https://github.com/opensearch-project/OpenSearch/pull/17598)) - Fix slow performance of FeatureFlag checks ([#17611](https://github.com/opensearch-project/OpenSearch/pull/17611)) +- Fix shard recovery in pull-based ingestion to avoid skipping messages ([#17868](https://github.com/opensearch-project/OpenSearch/pull/17868)) ### Security diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java index 4b4a44e13d1df..42019c5bfcd55 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java @@ -50,7 +50,7 @@ public class DefaultStreamPoller implements StreamPoller { private ExecutorService processorThread; // start of the batch, inclusive - private IngestionShardPointer batchStartPointer; + private IngestionShardPointer initialBatchStartPointer; private boolean includeBatchStartPointer = false; private ResetState resetState; @@ -105,7 +105,7 @@ public DefaultStreamPoller( this.consumer = Objects.requireNonNull(consumer); this.resetState = resetState; this.resetValue = resetValue; - this.batchStartPointer = startPointer; + this.initialBatchStartPointer = startPointer; this.state = initialState; this.persistedPointers = persistedPointers; if (!this.persistedPointers.isEmpty()) { @@ -170,23 +170,23 @@ protected void startPoll() { if (resetState != ResetState.NONE) { switch (resetState) { case EARLIEST: - batchStartPointer = consumer.earliestPointer(); - logger.info("Resetting offset by seeking to earliest offset {}", batchStartPointer.asString()); + 
initialBatchStartPointer = consumer.earliestPointer(); + logger.info("Resetting offset by seeking to earliest offset {}", initialBatchStartPointer.asString()); break; case LATEST: - batchStartPointer = consumer.latestPointer(); - logger.info("Resetting offset by seeking to latest offset {}", batchStartPointer.asString()); + initialBatchStartPointer = consumer.latestPointer(); + logger.info("Resetting offset by seeking to latest offset {}", initialBatchStartPointer.asString()); break; case REWIND_BY_OFFSET: - batchStartPointer = consumer.pointerFromOffset(resetValue); - logger.info("Resetting offset by seeking to offset {}", batchStartPointer.asString()); + initialBatchStartPointer = consumer.pointerFromOffset(resetValue); + logger.info("Resetting offset by seeking to offset {}", initialBatchStartPointer.asString()); break; case REWIND_BY_TIMESTAMP: - batchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue)); + initialBatchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue)); logger.info( "Resetting offset by seeking to timestamp {}, corresponding offset {}", resetValue, - batchStartPointer.asString() + initialBatchStartPointer.asString() ); break; } @@ -209,7 +209,8 @@ protected void startPoll() { List> results; if (includeBatchStartPointer) { - results = consumer.readNext(batchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT); + results = consumer.readNext(initialBatchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT); + includeBatchStartPointer = false; } else { results = consumer.readNext(MAX_POLL_SIZE, POLL_TIMEOUT); } @@ -220,38 +221,47 @@ protected void startPoll() { } state = State.PROCESSING; - // process the records - boolean firstInBatch = true; - for (IngestionShardConsumer.ReadResult result : results) { - if (firstInBatch) { - // update the batch start pointer to the next batch - batchStartPointer = result.getPointer(); - firstInBatch = false; - } + processRecords(results); + } catch (Exception e) { + // 
Pause ingestion when an error is encountered while polling the streaming source. + // Currently we do not have a good way to skip past the failing messages. + // The user will have the option to manually update the offset and resume ingestion. + // todo: support retry? + logger.error("Pausing ingestion. Fatal error occurred in polling the shard {}: {}", consumer.getShardId(), e); + pause(); + } + } + } - // check if the message is already processed - if (isProcessed(result.getPointer())) { - logger.info("Skipping message with pointer {} as it is already processed", result.getPointer().asString()); - continue; - } - totalPolledCount.inc(); - blockingQueue.put(result); - - logger.debug( - "Put message {} with pointer {} to the blocking queue", - String.valueOf(result.getMessage().getPayload()), - result.getPointer().asString() - ); + private void processRecords(List> results) { + for (IngestionShardConsumer.ReadResult result : results) { + try { + // check if the message is already processed + if (isProcessed(result.getPointer())) { + logger.debug("Skipping message with pointer {} as it is already processed", () -> result.getPointer().asString()); + continue; } - // for future reads, we do not need to include the batch start pointer, and read from the last successful pointer. - includeBatchStartPointer = false; - } catch (Throwable e) { - logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e); + totalPolledCount.inc(); + blockingQueue.put(result); + + logger.debug( + "Put message {} with pointer {} to the blocking queue", + String.valueOf(result.getMessage().getPayload()), + result.getPointer().asString() + ); + } catch (Exception e) { + logger.error( + "Error in processing a record. 
Shard {}, pointer {}: {}", + consumer.getShardId(), + result.getPointer().asString(), + e + ); errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.POLLING); if (!errorStrategy.shouldIgnoreError(e, IngestionErrorStrategy.ErrorStage.POLLING)) { // Blocking error encountered. Pause poller to stop processing remaining updates. pause(); + break; } } } @@ -329,9 +339,16 @@ public boolean isClosed() { return closed; } + /** + * Returns the batch start pointer from where the poller can resume in case of shard recovery. The poller and + * processor are decoupled in this implementation, and hence the latest pointer tracked by the processor acts as the + * recovery/start point. In case the processor has not started tracking, then the initial batchStartPointer used by + * the poller acts as the start point. + */ @Override public IngestionShardPointer getBatchStartPointer() { - return batchStartPointer; + IngestionShardPointer currentShardPointer = processorRunnable.getCurrentShardPointer(); + return currentShardPointer == null ? 
initialBatchStartPointer : currentShardPointer; } @Override diff --git a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java index 28de7224f9d89..c1d098279a7eb 100644 --- a/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java +++ b/server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java @@ -13,6 +13,7 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.index.Term; import org.opensearch.action.DocWriteRequest; +import org.opensearch.common.Nullable; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.common.metrics.CounterMetric; import org.opensearch.common.util.RequestUtils; @@ -59,6 +60,10 @@ public class MessageProcessorRunnable implements Runnable { private final MessageProcessor messageProcessor; private final CounterMetric stats = new CounterMetric(); + // tracks the most recent pointer that is being processed + @Nullable + private volatile IngestionShardPointer currentShardPointer; + /** * Constructor. 
* @@ -274,6 +279,7 @@ public void run() { if (readResult != null) { try { stats.inc(); + currentShardPointer = readResult.getPointer(); messageProcessor.process(readResult.getMessage(), readResult.getPointer()); readResult = null; } catch (Exception e) { @@ -308,4 +314,9 @@ public IngestionErrorStrategy getErrorStrategy() { public void setErrorStrategy(IngestionErrorStrategy errorStrategy) { this.errorStrategy = errorStrategy; } + + @Nullable + public IngestionShardPointer getCurrentShardPointer() { + return currentShardPointer; + } } diff --git a/server/src/test/java/org/opensearch/index/engine/IngestionEngineTests.java b/server/src/test/java/org/opensearch/index/engine/IngestionEngineTests.java index d8c5ebb16a36a..a510f92f9dd4c 100644 --- a/server/src/test/java/org/opensearch/index/engine/IngestionEngineTests.java +++ b/server/src/test/java/org/opensearch/index/engine/IngestionEngineTests.java @@ -102,7 +102,7 @@ public void testCreateEngine() throws IOException { // verify the commit data Assert.assertEquals(7, commitData.size()); // the commiit data is the start of the current batch - Assert.assertEquals("0", commitData.get(StreamPoller.BATCH_START)); + Assert.assertEquals("1", commitData.get(StreamPoller.BATCH_START)); // verify the stored offsets var offset = new FakeIngestionSource.FakeIngestionShardPointer(0); diff --git a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java index 6d71a3763fbc9..37ff7eeb27f4c 100644 --- a/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java +++ b/server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java @@ -267,13 +267,14 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte ); IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class); when(mockConsumer.getShardId()).thenReturn(0); - 
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed")) - .thenReturn(readResultsBatch1) - .thenThrow(new RuntimeException("message3 poll failed")) - .thenReturn(readResultsBatch2) - .thenReturn(Collections.emptyList()); + when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1); + when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList()); IngestionErrorStrategy errorStrategy = spy(new DropIngestionErrorStrategy("ingestion_source")); + ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class); + doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any()); + processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy); + poller = new DefaultStreamPoller( new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, @@ -288,7 +289,7 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte Thread.sleep(sleepTime); verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); - verify(processor, times(2)).process(any(), any()); + verify(mockQueue, times(4)).put(any()); } public void testBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException { @@ -314,12 +315,14 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt ); IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class); when(mockConsumer.getShardId()).thenReturn(0); - when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed")) - .thenReturn(readResultsBatch1) - .thenReturn(readResultsBatch2) - .thenReturn(Collections.emptyList()); + when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1); + when(mockConsumer.readNext(anyLong(), 
anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList()); IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source")); + ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class); + doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any()); + processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy); + poller = new DefaultStreamPoller( new FakeIngestionSource.FakeIngestionShardPointer(0), persistedPointers, @@ -334,7 +337,6 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt Thread.sleep(sleepTime); verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING)); - verify(processor, never()).process(any(), any()); assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState()); assertTrue(poller.isPaused()); } @@ -374,4 +376,54 @@ public void testUpdateErrorStrategy() { assertTrue(poller.getErrorStrategy() instanceof BlockIngestionErrorStrategy); assertTrue(processorRunnable.getErrorStrategy() instanceof BlockIngestionErrorStrategy); } + + public void testPersistedBatchStartPointer() throws TimeoutException, InterruptedException { + messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8)); + messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8)); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch1 = fakeConsumer.readNext( + fakeConsumer.earliestPointer(), + true, + 2, + 100 + ); + List< + IngestionShardConsumer.ReadResult< + FakeIngestionSource.FakeIngestionShardPointer, + FakeIngestionSource.FakeIngestionMessage>> readResultsBatch2 = fakeConsumer.readNext( + new FakeIngestionSource.FakeIngestionShardPointer(2), + true, + 2, + 100 + ); + + // This test publishes 4 messages, so use blocking 
queue of size 3. This ensures the poller is blocked when adding the 4th message + // for validation. + IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source")); + doThrow(new RuntimeException()).when(processor).process(any(), any()); + processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(3), processor, errorStrategy); + + IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class); + when(mockConsumer.getShardId()).thenReturn(0); + when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1); + + when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList()); + + poller = new DefaultStreamPoller( + new FakeIngestionSource.FakeIngestionShardPointer(0), + persistedPointers, + mockConsumer, + processorRunnable, + StreamPoller.ResetState.NONE, + "", + errorStrategy, + StreamPoller.State.NONE + ); + poller.start(); + Thread.sleep(sleepTime); + + assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(0), poller.getBatchStartPointer()); + } }