diff --git a/core/src/main/java/kafka/server/share/SharePartition.java b/core/src/main/java/kafka/server/share/SharePartition.java index 900ce8ca68461..3e311e3fcd992 100644 --- a/core/src/main/java/kafka/server/share/SharePartition.java +++ b/core/src/main/java/kafka/server/share/SharePartition.java @@ -19,6 +19,7 @@ import kafka.server.ReplicaManager; import kafka.server.share.SharePartitionManager.SharePartitionListener; +import org.apache.kafka.clients.consumer.AcknowledgeType; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.TopicIdPartition; import org.apache.kafka.common.Uuid; @@ -138,6 +139,16 @@ enum SharePartitionState { FENCED } + /** + * To provide static mapping between acknowledgement type bytes to RecordState. + */ + private static final Map ACK_TYPE_TO_RECORD_STATE = Map.of( + (byte) 0, RecordState.ARCHIVED, // Represents gap + AcknowledgeType.ACCEPT.id, RecordState.ACKNOWLEDGED, + AcknowledgeType.RELEASE.id, RecordState.AVAILABLE, + AcknowledgeType.REJECT.id, RecordState.ARCHIVED + ); + /** * The group id of the share partition belongs to. */ @@ -916,9 +927,9 @@ public CompletableFuture acknowledge( for (ShareAcknowledgementBatch batch : acknowledgementBatches) { // Client can either send a single entry in acknowledgeTypes which represents the state // of the complete batch or can send individual offsets state. - Map recordStateMap; + Map ackTypeMap; try { - recordStateMap = fetchRecordStateMapForAcknowledgementBatch(batch); + ackTypeMap = fetchAckTypeMapForBatch(batch); } catch (IllegalArgumentException e) { log.debug("Invalid acknowledge type: {} for share partition: {}-{}", batch.acknowledgeTypes(), groupId, topicIdPartition); @@ -946,7 +957,7 @@ public CompletableFuture acknowledge( Optional ackThrowable = acknowledgeBatchRecords( memberId, batch, - recordStateMap, + ackTypeMap, subMap, persisterBatches ); @@ -1856,26 +1867,21 @@ private boolean checkForStartOffsetWithinBatch(long batchFirstOffset, long batch return batchFirstOffset < localStartOffset && batchLastOffset >= localStartOffset; } - private Map fetchRecordStateMapForAcknowledgementBatch( - ShareAcknowledgementBatch batch) { + // Visibility for test + static Map fetchAckTypeMapForBatch(ShareAcknowledgementBatch batch) { // Client can either send a single entry in acknowledgeTypes which represents the state // of the complete batch or can send individual offsets state. Construct a map with record state // for each offset in the batch, if single acknowledge type is sent, the map will have only one entry. - Map recordStateMap = new HashMap<>(); + Map ackTypeMap = new HashMap<>(); for (int index = 0; index < batch.acknowledgeTypes().size(); index++) { - recordStateMap.put(batch.firstOffset() + index, - fetchRecordState(batch.acknowledgeTypes().get(index))); + byte ackType = batch.acknowledgeTypes().get(index); + // Validate + if (ackType != 0) { + AcknowledgeType.forId(ackType); + } + ackTypeMap.put(batch.firstOffset() + index, ackType); } - return recordStateMap; - } - - private static RecordState fetchRecordState(byte acknowledgeType) { - return switch (acknowledgeType) { - case 1 /* ACCEPT */ -> RecordState.ACKNOWLEDGED; - case 2 /* RELEASE */ -> RecordState.AVAILABLE; - case 3, 0 /* REJECT / GAP */ -> RecordState.ARCHIVED; - default -> throw new IllegalArgumentException("Invalid acknowledge type: " + acknowledgeType); - }; + return ackTypeMap; } private NavigableMap fetchSubMapForAcknowledgementBatch( @@ -1930,7 +1936,7 @@ private NavigableMap fetchSubMapForAcknowledgementBatch( private Optional acknowledgeBatchRecords( String memberId, ShareAcknowledgementBatch batch, - Map recordStateMap, + Map ackTypeMap, NavigableMap subMap, List persisterBatches ) { @@ -1994,11 +2000,11 @@ private Optional acknowledgeBatchRecords( } throwable = acknowledgePerOffsetBatchRecords(memberId, batch, inFlightBatch, - recordStateMap, persisterBatches); + ackTypeMap, persisterBatches); } else { // The in-flight batch is a full match hence change the state of the complete batch. throwable = acknowledgeCompleteBatch(batch, inFlightBatch, - recordStateMap.get(batch.firstOffset()), persisterBatches); + ackTypeMap.get(batch.firstOffset()), persisterBatches, memberId); } if (throwable.isPresent()) { @@ -2034,14 +2040,11 @@ private Optional acknowledgePerOffsetBatchRecords( String memberId, ShareAcknowledgementBatch batch, InFlightBatch inFlightBatch, - Map recordStateMap, + Map ackTypeMap, List persisterBatches ) { lock.writeLock().lock(); try { - // Fetch the first record state from the map to be used as default record state in case the - // offset record state is not provided by client. - RecordState recordStateDefault = recordStateMap.get(batch.firstOffset()); for (Map.Entry offsetState : inFlightBatch.offsetState().entrySet()) { // 1. For the first batch which might have offsets prior to the request base @@ -2081,31 +2084,50 @@ private Optional acknowledgePerOffsetBatchRecords( new InvalidRecordStateException("Member is not the owner of offset")); } - // Determine the record state for the offset. If the per offset record state is not provided - // by the client, then use the batch record state. - RecordState recordState = - recordStateMap.size() > 1 ? recordStateMap.get(offsetState.getKey()) : - recordStateDefault; - InFlightState updateResult = offsetState.getValue().startStateTransition( - recordState, - DeliveryCountOps.NO_OP, - this.maxDeliveryCount, - EMPTY_MEMBER_ID - ); - if (updateResult == null) { - log.debug("Unable to acknowledge records for the offset: {} in batch: {}" - + " for the share partition: {}-{}", offsetState.getKey(), - inFlightBatch, groupId, topicIdPartition); - return Optional.of(new InvalidRecordStateException( - "Unable to acknowledge records for the batch")); - } - // Successfully updated the state of the offset and created a persister state batch for write to persister. - persisterBatches.add(new PersisterBatch(updateResult, new PersisterStateBatch(offsetState.getKey(), - offsetState.getKey(), updateResult.state().id(), (short) updateResult.deliveryCount()))); - if (isStateTerminal(updateResult.state())) { - deliveryCompleteCount.incrementAndGet(); + // In case of 0 size ackTypeMap, we have already validated the batch.acknowledgeTypes. + byte ackType = ackTypeMap.size() > 1 ? ackTypeMap.get(offsetState.getKey()) : batch.acknowledgeTypes().get(0); + + if (ackType == AcknowledgeType.RENEW.id) { + // If RENEW, renew the acquisition lock timer for this offset and continue without changing state. + // We do not care about recordState map here. + // Only valid for ACQUIRED offsets; the check above ensures this. + long key = offsetState.getKey(); + InFlightState state = offsetState.getValue(); + log.debug("Renewing acq lock for {}-{} with offset {} in batch {} for member {}.", + groupId, topicIdPartition, key, inFlightBatch, memberId); + state.cancelAndClearAcquisitionLockTimeoutTask(); + AcquisitionLockTimerTask renewalTask = scheduleAcquisitionLockTimeout(memberId, key, key); + state.updateAcquisitionLockTimeoutTask(renewalTask); + } else { + // Determine the record state for the offset. If the per offset record state is not provided + // by the client, then use the batch record state. This will always be present as it is a static + // mapping between bytes and record state type. All ack types have been added except for RENEW which + // has been handled above. + RecordState recordState = ACK_TYPE_TO_RECORD_STATE.get(ackType); + Objects.requireNonNull(recordState); + + InFlightState updateResult = offsetState.getValue().startStateTransition( + recordState, + DeliveryCountOps.NO_OP, + this.maxDeliveryCount, + EMPTY_MEMBER_ID + ); + + if (updateResult == null) { + log.debug("Unable to acknowledge records for the offset: {} in batch: {}" + + " for the share partition: {}-{}", offsetState.getKey(), + inFlightBatch, groupId, topicIdPartition); + return Optional.of(new InvalidRecordStateException( + "Unable to acknowledge records for the batch")); + } + // Successfully updated the state of the offset and created a persister state batch for write to persister. + persisterBatches.add(new PersisterBatch(updateResult, new PersisterStateBatch(offsetState.getKey(), + offsetState.getKey(), updateResult.state().id(), (short) updateResult.deliveryCount()))); + if (isStateTerminal(updateResult.state())) { + deliveryCompleteCount.incrementAndGet(); + } + // Do not update the nextFetchOffset as the offset has not completed the transition yet. } - // Do not update the nextFetchOffset as the offset has not completed the transition yet. } } finally { lock.writeLock().unlock(); @@ -2116,8 +2138,9 @@ private Optional acknowledgePerOffsetBatchRecords( private Optional acknowledgeCompleteBatch( ShareAcknowledgementBatch batch, InFlightBatch inFlightBatch, - RecordState recordState, - List persisterBatches + byte ackType, + List persisterBatches, + String memberId ) { lock.writeLock().lock(); try { @@ -2131,10 +2154,30 @@ private Optional acknowledgeCompleteBatch( "The batch cannot be acknowledged. The batch is not in the acquired state.")); } + // If the request is a full-batch RENEW acknowledgement (ack type 4), then renew the + // acquisition lock without changing the state or persisting anything. + // Before reaching this point, it should be verified that it is full batch ack and + // not per offset ack as well as startOffset not moved. + if (ackType == AcknowledgeType.RENEW.id) { + // Renew the acquisition lock timer for the complete batch. We have already + // checked that the batchState is ACQUIRED above. + log.debug("Renewing acq lock for {}-{} with batch {}-{} for member {}.", + groupId, topicIdPartition, inFlightBatch.firstOffset(), inFlightBatch.lastOffset(), memberId); + inFlightBatch.cancelAndClearAcquisitionLockTimeoutTask(); + AcquisitionLockTimerTask renewalTask = scheduleAcquisitionLockTimeout(memberId, + inFlightBatch.firstOffset(), inFlightBatch.lastOffset()); + inFlightBatch.updateAcquisitionLockTimeout(renewalTask); + // Nothing to persist. + return Optional.empty(); + } + // Change the state of complete batch since the same state exists for the entire inFlight batch. - // The member id is reset to EMPTY_MEMBER_ID irrespective of the acknowledge type as the batch is + // The member id is reset to EMPTY_MEMBER_ID irrespective of the ack type as the batch is // either released or moved to a state where member id existence is not important. The member id // is only important when the batch is acquired. + RecordState recordState = ACK_TYPE_TO_RECORD_STATE.get(ackType); + Objects.requireNonNull(recordState); + InFlightState updateResult = inFlightBatch.startBatchStateTransition( recordState, DeliveryCountOps.NO_OP, @@ -3121,4 +3164,9 @@ private record LastOffsetAndMaxRecords( long lastOffset, int maxRecords ) { } + + // Visibility for testing + static Map ackTypeToRecordStateMapping() { + return ACK_TYPE_TO_RECORD_STATE; + } } diff --git a/core/src/test/java/kafka/server/share/SharePartitionTest.java b/core/src/test/java/kafka/server/share/SharePartitionTest.java index c9514c8ec3745..eb20dc9b8e3e5 100644 --- a/core/src/test/java/kafka/server/share/SharePartitionTest.java +++ b/core/src/test/java/kafka/server/share/SharePartitionTest.java @@ -56,6 +56,7 @@ import org.apache.kafka.server.share.acknowledge.ShareAcknowledgementBatch; import org.apache.kafka.server.share.fetch.AcquisitionLockTimerTask; import org.apache.kafka.server.share.fetch.DelayedShareFetchGroupKey; +import org.apache.kafka.server.share.fetch.InFlightBatch; import org.apache.kafka.server.share.fetch.InFlightState; import org.apache.kafka.server.share.fetch.RecordState; import org.apache.kafka.server.share.fetch.ShareAcquiredRecords; @@ -102,10 +103,13 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -9893,6 +9897,342 @@ public void testRecordArchivedWithWriteStateRPCFailure() throws InterruptedExcep assertEquals(5, sharePartition.deliveryCompleteCount()); } + @Test + public void testAckTypeToRecordStateMapping() { + // This test will help catch bugs if the map changes. + Map actualMap = SharePartition.ackTypeToRecordStateMapping(); + assertEquals(4, actualMap.size()); + + Map expected = Map.of( + (byte) 0, RecordState.ARCHIVED, + AcknowledgeType.ACCEPT.id, RecordState.ACKNOWLEDGED, + AcknowledgeType.RELEASE.id, RecordState.AVAILABLE, + AcknowledgeType.REJECT.id, RecordState.ARCHIVED + ); + + for (byte key : expected.keySet()) { + assertEquals(expected.get(key), actualMap.get(key)); + } + } + + @Test + public void testFetchAckTypeMapForBatch() { + ShareAcknowledgementBatch batch = mock(ShareAcknowledgementBatch.class); + when(batch.acknowledgeTypes()).thenReturn(List.of((byte) -1)); + assertThrows(IllegalArgumentException.class, () -> SharePartition.fetchAckTypeMapForBatch(batch)); + } + + @Test + public void testRenewAcknowledgeWithCompleteBatchAck() throws InterruptedException { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + List records = fetchAcquiredRecords(sharePartition, memoryRecords(0, 1), 1); + assertEquals(1, records.size()); + assertEquals(records.get(0).firstOffset(), records.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batch = sharePartition.cachedState().get(0L); + AcquisitionLockTimerTask taskOrig = batch.batchAcquisitionLockTimeoutTask(); + + sharePartition.acknowledge(MEMBER_ID, List.of(new ShareAcknowledgementBatch(0, 0, List.of(AcknowledgeType.RENEW.id)))); + assertTrue(taskOrig.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrig, batch.batchAcquisitionLockTimeoutTask()); // Lock changes. + assertEquals(1, sharePartition.timer().size()); // Timer jobs + assertEquals(RecordState.ACQUIRED, batch.batchState()); + Mockito.verify(persister, Mockito.times(0)).writeState(Mockito.any()); // No persister call. + + // Expire timer + // On expiration state will transition to AVAILABLE resulting in persister write RPC + WriteShareGroupStateResult writeShareGroupStateResult = Mockito.mock(WriteShareGroupStateResult.class); + Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of( + new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of( + PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), Errors.NONE.message()))))); + when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult)); + + mockTimer.advanceClock(ACQUISITION_LOCK_TIMEOUT_MS + 1); // Trigger expire + + assertNull(batch.batchAcquisitionLockTimeoutTask()); + assertEquals(RecordState.AVAILABLE, batch.batchState()); // Verify batch record state + assertEquals(0, sharePartition.timer().size()); // Timer jobs + Mockito.verify(persister, Mockito.times(1)).writeState(Mockito.any()); // 1 persister call. + } + + @Test + public void testRenewAcknowledgeOnExpiredBatch() throws InterruptedException { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + List records = fetchAcquiredRecords(sharePartition, memoryRecords(0, 1), 1); + assertEquals(1, records.size()); + assertEquals(records.get(0).firstOffset(), records.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batch = sharePartition.cachedState().get(0L); + AcquisitionLockTimerTask taskOrig = batch.batchAcquisitionLockTimeoutTask(); + + // Expire acq lock timeout. + // Persister mocking for recordState transition. + WriteShareGroupStateResult writeShareGroupStateResult = Mockito.mock(WriteShareGroupStateResult.class); + Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of( + new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of( + PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), Errors.NONE.message()))))); + + when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult)); + + mockTimer.advanceClock(ACQUISITION_LOCK_TIMEOUT_MS + 1); + TestUtils.waitForCondition(() -> batch.batchAcquisitionLockTimeoutTask() == null, "Acq lock timeout not cancelled."); + CompletableFuture future = sharePartition.acknowledge(MEMBER_ID, List.of(new ShareAcknowledgementBatch(0, 0, List.of(AcknowledgeType.RENEW.id)))); + + assertTrue(future.isCompletedExceptionally()); + try { + future.get(); + fail("No exception thrown"); + } catch (Exception e) { + assertNotNull(e); + assertInstanceOf(InvalidRecordStateException.class, e.getCause()); + } + assertTrue(taskOrig.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrig, batch.batchAcquisitionLockTimeoutTask()); // Lock changes. + assertEquals(0, sharePartition.timer().size()); // Timer jobs + assertEquals(RecordState.AVAILABLE, batch.batchState()); + Mockito.verify(persister, Mockito.times(1)).writeState(Mockito.any()); // 1 persister call to update record state. + } + + @Test + public void testRenewAcknowledgeWithPerOffsetAck() throws InterruptedException { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + List records = fetchAcquiredRecords(sharePartition, memoryRecords(0, 2), 2); + assertEquals(1, records.size()); + assertEquals(records.get(0).firstOffset() + 1, records.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batch = sharePartition.cachedState().get(0L); + assertEquals(RecordState.ACQUIRED, batch.batchState()); + AcquisitionLockTimerTask taskOrig = batch.batchAcquisitionLockTimeoutTask(); + + // For ACCEPT ack call. + WriteShareGroupStateResult writeShareGroupStateResult = Mockito.mock(WriteShareGroupStateResult.class); + Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of( + new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of( + PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), Errors.NONE.message()))))); + + when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult)); + + sharePartition.acknowledge(MEMBER_ID, List.of(new ShareAcknowledgementBatch(0, 1, + List.of(AcknowledgeType.RENEW.id, AcknowledgeType.ACCEPT.id)))); + + assertTrue(taskOrig.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrig, sharePartition.cachedState().get(0L).offsetState().get(0L).acquisitionLockTimeoutTask()); + assertNotNull(sharePartition.cachedState().get(0L).offsetState()); + + InFlightState offset0 = sharePartition.cachedState().get(0L).offsetState().get(0L); + InFlightState offset1 = sharePartition.cachedState().get(0L).offsetState().get(1L); + assertEquals(RecordState.ACQUIRED, offset0.state()); + assertNotNull(offset0.acquisitionLockTimeoutTask()); + assertEquals(1, sharePartition.timer().size()); // Timer jobs + + assertEquals(RecordState.ACKNOWLEDGED, offset1.state()); + assertNull(offset1.acquisitionLockTimeoutTask()); + + Mockito.verify(persister, Mockito.times(1)).writeState(Mockito.any()); + + // Expire timer + mockTimer.advanceClock(ACQUISITION_LOCK_TIMEOUT_MS + 1); // Trigger expire + + assertNull(offset0.acquisitionLockTimeoutTask()); + assertEquals(RecordState.AVAILABLE, offset0.state()); // Verify batch record state + assertEquals(0, sharePartition.timer().size()); // Timer jobs + Mockito.verify(persister, Mockito.times(2)).writeState(Mockito.any()); // 1 more persister call. + } + + @Test + public void testLsoMovementWithBatchRenewal() { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + List records = fetchAcquiredRecords(sharePartition, memoryRecords(0, 10), 10); + assertEquals(1, records.size()); + assertNotEquals(records.get(0).firstOffset(), records.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batch = sharePartition.cachedState().get(0L); + AcquisitionLockTimerTask taskOrig = batch.batchAcquisitionLockTimeoutTask(); + + sharePartition.acknowledge(MEMBER_ID, List.of(new ShareAcknowledgementBatch(0, 9, List.of(AcknowledgeType.RENEW.id)))); + sharePartition.updateCacheAndOffsets(5); + + assertEquals(10, sharePartition.nextFetchOffset()); + assertEquals(5, sharePartition.startOffset()); + assertEquals(9, sharePartition.endOffset()); + assertEquals(1, sharePartition.cachedState().size()); + + assertEquals(MEMBER_ID, sharePartition.cachedState().get(0L).batchMemberId()); + assertEquals(RecordState.ACQUIRED, sharePartition.cachedState().get(0L).batchState()); + + assertTrue(taskOrig.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrig, batch.batchAcquisitionLockTimeoutTask()); // Lock changes. + assertEquals(1, sharePartition.timer().size()); // Timer jobs + Mockito.verify(persister, Mockito.times(0)).writeState(Mockito.any()); // No persister call. + } + + @Test + public void testLsoMovementWithPerOffsetRenewal() throws InterruptedException { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + List records = fetchAcquiredRecords(sharePartition, memoryRecords(0, 5), 5); + assertEquals(1, records.size()); + assertEquals(records.get(0).firstOffset() + 4, records.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batch = sharePartition.cachedState().get(0L); + assertEquals(RecordState.ACQUIRED, batch.batchState()); + AcquisitionLockTimerTask taskOrig = batch.batchAcquisitionLockTimeoutTask(); + + // For ACCEPT ack call. + WriteShareGroupStateResult writeShareGroupStateResult = Mockito.mock(WriteShareGroupStateResult.class); + Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of( + new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of( + PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), Errors.NONE.message()))))); + + when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult)); + + sharePartition.acknowledge(MEMBER_ID, List.of(new ShareAcknowledgementBatch(0, 4, + List.of(AcknowledgeType.RENEW.id, AcknowledgeType.ACCEPT.id, AcknowledgeType.RENEW.id, AcknowledgeType.ACCEPT.id, AcknowledgeType.RENEW.id)))); + + sharePartition.updateCacheAndOffsets(3); + + assertEquals(5, sharePartition.nextFetchOffset()); + assertEquals(3, sharePartition.startOffset()); + assertEquals(4, sharePartition.endOffset()); + assertEquals(1, sharePartition.cachedState().size()); + + assertTrue(taskOrig.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrig, sharePartition.cachedState().get(0L).offsetState().get(0L).acquisitionLockTimeoutTask()); + assertNotNull(sharePartition.cachedState().get(0L).offsetState()); + + InFlightState offset0 = sharePartition.cachedState().get(0L).offsetState().get(0L); + InFlightState offset1 = sharePartition.cachedState().get(0L).offsetState().get(1L); + InFlightState offset2 = sharePartition.cachedState().get(0L).offsetState().get(2L); + InFlightState offset3 = sharePartition.cachedState().get(0L).offsetState().get(3L); + InFlightState offset4 = sharePartition.cachedState().get(0L).offsetState().get(4L); + + assertEquals(RecordState.ACQUIRED, offset0.state()); + assertNotNull(offset0.acquisitionLockTimeoutTask()); + + assertEquals(RecordState.ACKNOWLEDGED, offset1.state()); + assertNull(offset1.acquisitionLockTimeoutTask()); + + assertEquals(RecordState.ACQUIRED, offset2.state()); + assertNotNull(offset2.acquisitionLockTimeoutTask()); + + assertEquals(RecordState.ACKNOWLEDGED, offset3.state()); + assertNull(offset3.acquisitionLockTimeoutTask()); + + assertEquals(RecordState.ACQUIRED, offset4.state()); + assertNotNull(offset4.acquisitionLockTimeoutTask()); + + assertEquals(3, sharePartition.timer().size()); // Timer jobs - 3 because the renewed offsets are non-contiguous. + + // Expire timer + mockTimer.advanceClock(ACQUISITION_LOCK_TIMEOUT_MS + 1); // Trigger expire + // todo: index 2 in expectedStates should be RecordState.ARCHIVED - fix after ticket KAFKA-19859 is addressed. + List expectedStates = List.of(RecordState.ARCHIVED, RecordState.ACKNOWLEDGED, RecordState.AVAILABLE, RecordState.ACKNOWLEDGED, RecordState.AVAILABLE); + for (long i = 0; i <= 4; i++) { + InFlightState offset = sharePartition.cachedState().get(0L).offsetState().get(i); + assertNull(offset.acquisitionLockTimeoutTask()); + assertEquals(expectedStates.get((int) i), offset.state()); + } + + assertEquals(0, sharePartition.timer().size()); // Timer jobs + + Mockito.verify(persister, Mockito.times(4)).writeState(Mockito.any()); + } + + @Test + public void testRenewAcknowledgeWithPerOffsetAndBatchMix() { + Persister persister = Mockito.mock(Persister.class); + SharePartition sharePartition = SharePartitionBuilder.builder() + .withState(SharePartitionState.ACTIVE) + .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS) + .withMaxDeliveryCount(2) + .withPersister(persister) + .build(); + + // Batch + List recordsB = fetchAcquiredRecords(sharePartition, memoryRecords(0, 1), 1); + assertEquals(1, recordsB.size()); + assertEquals(recordsB.get(0).firstOffset(), recordsB.get(0).lastOffset()); + assertEquals(1, sharePartition.cachedState().size()); + InFlightBatch batchB = sharePartition.cachedState().get(0L); + AcquisitionLockTimerTask taskOrigB = batchB.batchAcquisitionLockTimeoutTask(); + + // Per offset + List recordsO = fetchAcquiredRecords(sharePartition, memoryRecords(1, 2), 2); + assertEquals(1, recordsO.size()); + assertEquals(recordsO.get(0).firstOffset() + 1, recordsO.get(0).lastOffset()); + assertEquals(2, sharePartition.cachedState().size()); + InFlightBatch batchO = sharePartition.cachedState().get(0L); + assertEquals(RecordState.ACQUIRED, batchO.batchState()); + AcquisitionLockTimerTask taskOrigO = batchO.batchAcquisitionLockTimeoutTask(); + + // For ACCEPT ack call. + WriteShareGroupStateResult writeShareGroupStateResult = Mockito.mock(WriteShareGroupStateResult.class); + Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of( + new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of( + PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), Errors.NONE.message()))))); + + when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult)); + + sharePartition.acknowledge(MEMBER_ID, List.of( + new ShareAcknowledgementBatch(0, 0, List.of(AcknowledgeType.RENEW.id)), + new ShareAcknowledgementBatch(1, 2, List.of(AcknowledgeType.RENEW.id, AcknowledgeType.ACCEPT.id)) + )); + + // Batch checks + assertTrue(taskOrigB.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrigB, batchB.batchAcquisitionLockTimeoutTask()); // Lock changes. + + // Per offset checks + assertTrue(taskOrigO.isCancelled()); // Original acq lock cancelled. + assertNotEquals(taskOrigO, sharePartition.cachedState().get(1L).offsetState().get(1L).acquisitionLockTimeoutTask()); + assertNotNull(sharePartition.cachedState().get(1L).offsetState()); + + InFlightState offset1 = sharePartition.cachedState().get(1L).offsetState().get(1L); + InFlightState offset2 = sharePartition.cachedState().get(1L).offsetState().get(2L); + assertEquals(RecordState.ACQUIRED, offset1.state()); + assertNotNull(offset1.acquisitionLockTimeoutTask()); + + assertEquals(RecordState.ACKNOWLEDGED, offset2.state()); + assertNull(offset2.acquisitionLockTimeoutTask()); + + assertEquals(2, sharePartition.timer().size()); // Timer jobs one for batch and one for single renewal in per offset. + Mockito.verify(persister, Mockito.times(1)).writeState(Mockito.any()); + } + /** * This function produces transactional data of a given no. of records followed by a transactional marker (COMMIT/ABORT). */