Skip to content

Commit

Permalink
Merge pull request #349 from conductor-oss/fix/add-lease-extension
Browse files Browse the repository at this point in the history
Added lease extension to java sdk and fixed tests
ystxn authored Dec 30, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents e0780c3 + 395c672 commit 99b0fc8
Showing 10 changed files with 194 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -21,8 +21,7 @@ dependencies {
testImplementation "org.junit.jupiter:junit-jupiter-api:${versions.junit}"
testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${versions.junit}"

testImplementation "org.powermock:powermock-module-junit4:2.0.9"
testImplementation "org.powermock:powermock-api-mockito2:2.0.9"
testImplementation 'org.mockito:mockito-inline:5.2.0'

testImplementation 'org.spockframework:spock-core:2.3-groovy-3.0'
testImplementation 'org.codehaus.groovy:groovy:3.0.15'
@@ -36,6 +35,7 @@ java {

test {
useJUnitPlatform()
maxParallelForks = 1
}

shadowJar {
Original file line number Diff line number Diff line change
@@ -16,13 +16,17 @@
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
@@ -68,6 +72,10 @@ class TaskRunner {
private final EventDispatcher<TaskRunnerEvent> eventDispatcher;
private final LinkedBlockingQueue<Task> tasksTobeExecuted;
private final boolean enableUpdateV2;
private static final int LEASE_EXTEND_RETRY_COUNT = 3;
private static final double LEASE_EXTEND_DURATION_FACTOR = 0.8;
private final ScheduledExecutorService leaseExtendExecutorService;
private Map<String, ScheduledFuture<?>> leaseExtendMap = new HashMap<>();

TaskRunner(Worker worker,
TaskClient taskClient,
@@ -122,6 +130,15 @@ class TaskRunner {
pollingIntervalInMillis,
domain);
LOGGER.info("Polling errors for taskType {} will be printed at every {} occurrence.", taskType, errorAt);

LOGGER.info("Initialized the task lease extend executor");
leaseExtendExecutorService = Executors.newSingleThreadScheduledExecutor(
new BasicThreadFactory.Builder()
.namingPattern("workflow-lease-extend-%d")
.daemon(true)
.uncaughtExceptionHandler(uncaughtExceptionHandler)
.build()
);
}

public void pollAndExecute() {
@@ -145,7 +162,25 @@ public void pollAndExecute() {
LOGGER.trace("Poller for task {} waited for {} ms before getting {} tasks to execute", taskType, stopwatch.elapsed(TimeUnit.MILLISECONDS), tasks.size());
stopwatch = null;
}
tasks.forEach(task -> this.executorService.submit(() -> this.processTask(task)));
tasks.forEach(task -> {
Future<Task> taskFuture = this.executorService.submit(() -> this.processTask(task));

if (task.getResponseTimeoutSeconds() > 0 && worker.leaseExtendEnabled()) {
ScheduledFuture<?> scheduledFuture = leaseExtendMap.get(task.getTaskId());
if (scheduledFuture != null) {
scheduledFuture.cancel(false);
}

long delay = Math.round(task.getResponseTimeoutSeconds() * LEASE_EXTEND_DURATION_FACTOR);
ScheduledFuture<?> leaseExtendFuture = leaseExtendExecutorService.scheduleWithFixedDelay(
extendLease(task, taskFuture),
delay,
delay,
TimeUnit.SECONDS
);
leaseExtendMap.put(task.getTaskId(), leaseExtendFuture);
}
});
} catch (Throwable t) {
LOGGER.error(t.getMessage(), t);
}
@@ -251,7 +286,7 @@ private List<Task> pollTask(int count) {
LOGGER.error("Uncaught exception. Thread {} will exit now", thread, error);
};

private void processTask(Task task) {
private Task processTask(Task task) {
eventDispatcher.publish(new TaskExecutionStarted(taskType, task.getTaskId(), worker.getIdentity()));
LOGGER.trace("Executing task: {} of type: {} in worker: {} at {}", task.getTaskId(), taskType, worker.getClass().getSimpleName(), worker.getIdentity());
LOGGER.trace("task {} is getting executed after {} ms of getting polled", task.getTaskId(), (System.currentTimeMillis() - task.getStartTime()));
@@ -271,6 +306,7 @@ private void processTask(Task task) {
} finally {
permits.release();
}
return task;
}

private void executeTask(Worker worker, Task task) {
@@ -400,4 +436,30 @@ private void handleException(Throwable t, TaskResult result, Worker worker, Task
result.log(stringWriter.toString());
updateTaskResult(updateRetryCount, task, result, worker);
}

private Runnable extendLease(Task task, Future<Task> taskCompletableFuture) {
return () -> {
if (taskCompletableFuture.isDone()) {
LOGGER.warn(
"Task processing for {} completed, but its lease extend was not cancelled",
task.getTaskId());
return;
}
LOGGER.info("Attempting to extend lease for {}", task.getTaskId());
try {
TaskResult result = new TaskResult(task);
result.setExtendLease(true);
retryOperation(
(TaskResult taskResult) -> {
taskClient.updateTask(taskResult);
return null;
},
LEASE_EXTEND_RETRY_COUNT,
result,
"extend lease");
} catch (Exception e) {
LOGGER.error("Failed to extend lease for {}", task.getTaskId(), e);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -260,6 +260,9 @@ public TaskRunnerConfigurer.Builder withTaskToDomain(Map<String, String> taskToD
public TaskRunnerConfigurer.Builder withTaskThreadCount(
Map<String, Integer> taskToThreadCount) {
this.taskToThreadCount = taskToThreadCount;
if (taskToThreadCount.values().stream().anyMatch(v -> v < 1)) {
throw new IllegalArgumentException("No. of threads cannot be less than 1");
}
return this;
}

Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ public interface Worker {
String PROP_ALL_WORKERS = "all";
String PROP_LOG_INTERVAL = "log_interval";
String PROP_POLL_INTERVAL = "poll_interval";
String PROP_LEASE_EXTEND_ENABLED = "leaseExtendEnabled";
String PROP_PAUSED = "paused";

/**
@@ -91,6 +92,10 @@ default int getPollingInterval() {
return PropertyFactory.getInteger(getTaskDefName(), PROP_POLL_INTERVAL, 1000);
}

default boolean leaseExtendEnabled() {
return PropertyFactory.getBoolean(getTaskDefName(), PROP_LEASE_EXTEND_ENABLED, false);
}

static Worker create(String taskType, Function<Task, TaskResult> executor) {
return new Worker() {

Original file line number Diff line number Diff line change
@@ -15,58 +15,67 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import com.netflix.conductor.client.exception.ConductorClientException;
import com.netflix.conductor.client.http.TaskClient;
import com.netflix.conductor.client.worker.Worker;
import com.netflix.conductor.common.metadata.tasks.Task;
import com.netflix.conductor.common.metadata.tasks.TaskResult;

import static com.netflix.conductor.common.metadata.tasks.TaskResult.Status.COMPLETED;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TaskRunnerConfigurerTest {

private static final String TEST_TASK_DEF_NAME = "test";

private TaskClient client;

@Before
@BeforeEach
public void setup() {
client = Mockito.mock(TaskClient.class);
}

@Test(expected = NullPointerException.class)
@Test
public void testNoWorkersException() {
new TaskRunnerConfigurer.Builder(null, null).build();
assertThrows(NullPointerException.class, () -> new TaskRunnerConfigurer.Builder(null, null).build());
}

@Test(expected = ConductorClientException.class)
@Test
public void testInvalidThreadConfig() {
Worker worker1 = Worker.create("task1", TaskResult::new);
Worker worker2 = Worker.create("task2", TaskResult::new);
Map<String, Integer> taskThreadCount = new HashMap<>();
taskThreadCount.put(worker1.getTaskDefName(), 2);
taskThreadCount.put(worker1.getTaskDefName(), 0);
taskThreadCount.put(worker2.getTaskDefName(), 3);
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withThreadCount(10)
.withTaskThreadCount(taskThreadCount)
.build();

assertThrows(IllegalArgumentException.class, () -> new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withThreadCount(-1)
.withTaskThreadCount(taskThreadCount)
.build());

assertThrows(IllegalArgumentException.class, () -> new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withTaskThreadCount(taskThreadCount)
.build());
}

@Test
@@ -81,12 +90,12 @@ public void testMissingTaskThreadConfig() {
.build();

assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(2, configurer.getTaskThreadCount().size());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(2, configurer.getTaskThreadCount().get("task1").intValue());
assertEquals(1, configurer.getTaskThreadCount().get("task2").intValue());
}

@Test
@SuppressWarnings("deprecation")
public void testPerTaskThreadPool() {
Worker worker1 = Worker.create("task1", TaskResult::new);
Worker worker2 = Worker.create("task2", TaskResult::new);
@@ -104,19 +113,18 @@ public void testPerTaskThreadPool() {
}

@Test
@SuppressWarnings("deprecation")
public void testSharedThreadPool() {
Worker worker = Worker.create(TEST_TASK_DEF_NAME, TaskResult::new);
TaskRunnerConfigurer configurer =
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker, worker, worker))
.build();
configurer.init();
assertEquals(3, configurer.getThreadCount());
assertEquals(-1, configurer.getThreadCount());
assertEquals(500, configurer.getSleepWhenRetry());
assertEquals(3, configurer.getUpdateRetryCount());
assertEquals(10, configurer.getShutdownGracePeriodSeconds());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(3, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());
assertTrue(configurer.getTaskThreadCount().isEmpty());

configurer =
new TaskRunnerConfigurer.Builder(client, Collections.singletonList(worker))
@@ -133,9 +141,7 @@ public void testSharedThreadPool() {
assertEquals(10, configurer.getUpdateRetryCount());
assertEquals(15, configurer.getShutdownGracePeriodSeconds());
assertEquals("test-worker-", configurer.getWorkerNamePrefix());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(100, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());
assertTrue(configurer.getTaskThreadCount().isEmpty());
}

@Test
@@ -175,7 +181,7 @@ public void testMultipleWorkersExecution() throws Exception {
TaskClient taskClient = Mockito.mock(TaskClient.class);
TaskRunnerConfigurer configurer =
new TaskRunnerConfigurer.Builder(taskClient, Arrays.asList(worker1, worker2))
.withThreadCount(2)
.withThreadCount(1)
.withSleepWhenRetry(100000)
.withUpdateRetryCount(1)
.withWorkerNamePrefix("test-worker-")
@@ -186,9 +192,9 @@ public void testMultipleWorkersExecution() throws Exception {
Object[] args = invocation.getArguments();
String taskName = args[0].toString();
if (taskName.equals(task1Name)) {
return Arrays.asList(task1);
return List.of(task1);
} else if (taskName.equals(task2Name)) {
return Arrays.asList(task2);
return List.of(task2);
} else {
return Collections.emptyList();
}
@@ -220,6 +226,58 @@ public void testMultipleWorkersExecution() throws Exception {
assertEquals(1, task2Counter.get());
}

@Test
public void testLeaseExtension() throws Exception {
TaskClient taskClient = mock(TaskClient.class);
String taskName = "task1";

Worker worker = mock(Worker.class);
when(worker.getTaskDefName()).thenReturn(taskName);
when(worker.leaseExtendEnabled()).thenReturn(true);

doAnswer(invocation -> {
TaskResult result = new TaskResult(invocation.getArgument(0));
result.setStatus(TaskResult.Status.IN_PROGRESS);
return result;
}).when(worker).execute(any(Task.class));

Task task = new Task();
task.setTaskId("task123");
task.setTaskDefName(taskName);
task.setStatus(Task.Status.IN_PROGRESS);
task.setResponseTimeoutSeconds(2000);

when(taskClient.batchPollTasksInDomain(any(), any(), any(), anyInt(), anyInt()))
.thenAnswer((invocation) -> List.of(task));
when(taskClient.ack(any(), any())).thenReturn(true);

CountDownLatch latch = new CountDownLatch(1);
doAnswer(invocation -> {
latch.countDown();
return null;
}).when(taskClient).updateTask(any(TaskResult.class));

TaskRunnerConfigurer configurer = new TaskRunnerConfigurer.Builder(taskClient, List.of(worker))
.withSleepWhenRetry(100)
.withUpdateRetryCount(3)
.withThreadCount(1)
.build();

configurer.init();
latch.await();

ArgumentCaptor<TaskResult> taskResultCaptor = ArgumentCaptor.forClass(TaskResult.class);
verify(taskClient, atLeastOnce()).updateTask(taskResultCaptor.capture());

TaskResult capturedResult = taskResultCaptor.getValue();
assertNotNull(capturedResult);
assertEquals("task123", capturedResult.getTaskId());
assertEquals(TaskResult.Status.IN_PROGRESS, capturedResult.getStatus());

verify(worker, atLeastOnce()).execute(task);
assertTrue(worker.leaseExtendEnabled(), "Worker lease extension should be enabled");
}

private Task testTask(String taskDefName) {
Task task = new Task();
task.setTaskId(UUID.randomUUID().toString());
Original file line number Diff line number Diff line change
@@ -12,16 +12,17 @@
*/
package com.netflix.conductor.client.config;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import com.netflix.conductor.client.worker.Worker;
import com.netflix.conductor.common.metadata.tasks.TaskResult;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;


public class TestPropertyFactory {

@@ -30,14 +31,14 @@ public void testIdentity() {
Worker worker = Worker.create("Test2", TaskResult::new);
assertNotNull(worker.getIdentity());
boolean paused = worker.paused();
assertFalse("Paused? " + paused, paused);
assertFalse(paused);
}

@Test
public void test() {

int val = PropertyFactory.getInteger("workerB", "pollingInterval", 100);
assertEquals("got: " + val, 2, val);
assertEquals(2, val);
assertEquals(
100, PropertyFactory.getInteger("workerB", "propWithoutValue", 100).intValue());

@@ -67,6 +68,6 @@ public void test() {
public void testProperty() {
Worker worker = Worker.create("Test", TaskResult::new);
boolean paused = worker.paused();
assertTrue("Paused? " + paused, paused);
assertTrue(paused);
}
}
Original file line number Diff line number Diff line change
@@ -15,8 +15,8 @@
import java.io.InputStream;
import java.util.List;

import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import com.netflix.conductor.common.config.ObjectMapperProvider;
import com.netflix.conductor.common.metadata.tasks.Task;
@@ -25,14 +25,14 @@

import com.fasterxml.jackson.databind.ObjectMapper;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

public class TestWorkflowTask {

private ObjectMapper objectMapper;

@Before
@BeforeEach
public void setup() {
objectMapper = new ObjectMapperProvider().getObjectMapper();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
conductor.worker.pollingInterval=2
conductor.worker.paused=false
conductor.worker.workerA.paused=true
conductor.worker.workerA.domain=domainA
conductor.worker.workerB.batchSize=84
conductor.worker.workerB.domain=domainB
conductor.worker.Test.paused=true
conductor.worker.domainTestTask2.domain=visinghDomain
conductor.worker.task_run_always.pollOutOfDiscovery=true
conductor.worker.task_explicit_do_not_run_always.pollOutOfDiscovery=false
conductor.worker.task_ignore_override.pollOutOfDiscovery=true
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ BulkResponse<String> pauseWorkflows(List<String> workflowIds) {
.body(workflowIds)
.build();

ConductorClientResponse<BulkResponse> resp = client.execute(request, new TypeReference<>() {
ConductorClientResponse<BulkResponse<String>> resp = client.execute(request, new TypeReference<>() {
});

return resp.getData();
@@ -52,7 +52,7 @@ BulkResponse<String> restartWorkflows(List<String> workflowIds, Boolean useLates
.body(workflowIds)
.build();

ConductorClientResponse<BulkResponse> resp = client.execute(request, new TypeReference<>() {
ConductorClientResponse<BulkResponse<String>> resp = client.execute(request, new TypeReference<>() {
});

return resp.getData();
@@ -65,7 +65,7 @@ BulkResponse<String> resumeWorkflows(List<String> workflowIds) {
.body(workflowIds)
.build();

ConductorClientResponse<BulkResponse> resp = client.execute(request, new TypeReference<>() {
ConductorClientResponse<BulkResponse<String>> resp = client.execute(request, new TypeReference<>() {
});

return resp.getData();
@@ -78,7 +78,7 @@ BulkResponse<String> retryWorkflows(List<String> workflowIds) {
.body(workflowIds)
.build();

ConductorClientResponse<BulkResponse> resp = client.execute(request, new TypeReference<>() {
ConductorClientResponse<BulkResponse<String>> resp = client.execute(request, new TypeReference<>() {
});

return resp.getData();
@@ -93,7 +93,7 @@ public BulkResponse<String> terminateWorkflows(List<String> workflowIds, String
.body(workflowIds)
.build();

ConductorClientResponse<BulkResponse> resp = client.execute(request, new TypeReference<>() {
ConductorClientResponse<BulkResponse<String>> resp = client.execute(request, new TypeReference<>() {
});

return resp.getData();
Original file line number Diff line number Diff line change
@@ -16,17 +16,17 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import com.netflix.conductor.common.metadata.tasks.TaskResult;
import com.netflix.conductor.common.metadata.workflow.WorkflowTask;
import com.netflix.conductor.common.run.WorkflowTestRequest;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import com.netflix.conductor.common.metadata.tasks.TaskDef;
import com.netflix.conductor.common.metadata.tasks.TaskResult;
import com.netflix.conductor.common.metadata.workflow.StartWorkflowRequest;
import com.netflix.conductor.common.metadata.workflow.WorkflowDef;
import com.netflix.conductor.common.metadata.workflow.WorkflowTask;
import com.netflix.conductor.common.run.Workflow;
import com.netflix.conductor.common.run.WorkflowTestRequest;
import com.netflix.conductor.sdk.workflow.def.ConductorWorkflow;
import com.netflix.conductor.sdk.workflow.def.tasks.Http;
import com.netflix.conductor.sdk.workflow.def.tasks.SimpleTask;

0 comments on commit 99b0fc8

Please sign in to comment.