diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskInfoFetcher.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskInfoFetcher.java index 2c6bcf0bc7f7..c28b1519a19e 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskInfoFetcher.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskInfoFetcher.java @@ -58,6 +58,7 @@ public class HttpNativeExecutionTaskInfoFetcher private final ScheduledExecutorService errorRetryScheduledExecutor; private final AtomicReference lastException = new AtomicReference<>(); private final Duration maxErrorDuration; + private final Object taskFinished; @GuardedBy("this") private ScheduledFuture scheduledFuture; @@ -68,7 +69,8 @@ public HttpNativeExecutionTaskInfoFetcher( PrestoSparkHttpTaskClient workerClient, Executor executor, Duration infoFetchInterval, - Duration maxErrorDuration) + Duration maxErrorDuration, + Object taskFinished) { this.workerClient = requireNonNull(workerClient, "workerClient is null"); this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null"); @@ -84,6 +86,7 @@ public HttpNativeExecutionTaskInfoFetcher( maxErrorDuration, errorRetryScheduledExecutor, "getting taskInfo from native process"); + this.taskFinished = requireNonNull(taskFinished, "taskFinished is null"); } public void start() @@ -101,6 +104,11 @@ public void onSuccess(BaseResponse result) { log.debug("TaskInfoCallback success %s", result.getValue().getTaskId()); taskInfo.set(result.getValue()); + if (result.getValue().getTaskStatus().getState().isDone()) { + synchronized (taskFinished) { + taskFinished.notifyAll(); + } + } } @Override diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskResultFetcher.java 
b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskResultFetcher.java index f8c85c72c21b..1a5b6b7d0caf 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskResultFetcher.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/HttpNativeExecutionTaskResultFetcher.java @@ -59,17 +59,20 @@ public class HttpNativeExecutionTaskResultFetcher private final PrestoSparkHttpTaskClient workerClient; private final LinkedBlockingDeque pageBuffer = new LinkedBlockingDeque<>(); private final AtomicLong bufferMemoryBytes; + private final Object taskHasResult; private ScheduledFuture schedulerFuture; private boolean started; public HttpNativeExecutionTaskResultFetcher( ScheduledExecutorService scheduler, - PrestoSparkHttpTaskClient workerClient) + PrestoSparkHttpTaskClient workerClient, + Object taskHasResult) { this.scheduler = requireNonNull(scheduler, "scheduler is null"); this.workerClient = requireNonNull(workerClient, "workerClient is null"); this.bufferMemoryBytes = new AtomicLong(); + this.taskHasResult = requireNonNull(taskHasResult, "taskHasResult is null"); } public CompletableFuture start() @@ -85,7 +88,8 @@ public CompletableFuture start() workerClient, future, pageBuffer, - bufferMemoryBytes), + bufferMemoryBytes, + taskHasResult), 0, (long) FETCH_INTERVAL.getValue(), FETCH_INTERVAL.getUnit()); @@ -125,6 +129,11 @@ public Optional pollPage() return Optional.empty(); } + public boolean hasPage() + { + return !pageBuffer.isEmpty(); + } + private static class HttpNativeExecutionTaskResultFetcherRunner implements Runnable { @@ -135,6 +144,7 @@ private static class HttpNativeExecutionTaskResultFetcherRunner private final LinkedBlockingDeque pageBuffer; private final AtomicLong bufferMemoryBytes; private final CompletableFuture future; + private final Object taskFinishedOrHasResult; private int transportErrorRetries; private long token; @@ -143,7 +153,8 @@ 
public HttpNativeExecutionTaskResultFetcherRunner( PrestoSparkHttpTaskClient client, CompletableFuture future, LinkedBlockingDeque pageBuffer, - AtomicLong bufferMemoryBytes) + AtomicLong bufferMemoryBytes, + Object taskFinishedOrHasResult) { this.client = requireNonNull(client, "client is null"); this.future = requireNonNull(future, "future is null"); @@ -151,6 +162,7 @@ public HttpNativeExecutionTaskResultFetcherRunner( this.bufferMemoryBytes = requireNonNull( bufferMemoryBytes, "bufferMemoryBytes is null"); + this.taskFinishedOrHasResult = requireNonNull(taskFinishedOrHasResult, "taskFinishedOrHasResult is null"); } @Override @@ -184,6 +196,11 @@ public void run() client.abortResults(); future.complete(null); } + if (!pages.isEmpty()) { + synchronized (taskFinishedOrHasResult) { + taskFinishedOrHasResult.notifyAll(); + } + } } catch (InterruptedException e) { if (!future.isDone()) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/NativeExecutionTask.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/NativeExecutionTask.java index 0a215ef363b1..971de33f0497 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/NativeExecutionTask.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/NativeExecutionTask.java @@ -73,6 +73,7 @@ public class NativeExecutionTask private final HttpNativeExecutionTaskInfoFetcher taskInfoFetcher; // Results will be fetched only if not written to shuffle. 
private final Optional taskResultFetcher; + private final Object taskFinishedOrHasResult = new Object(); public NativeExecutionTask( Session session, @@ -116,11 +117,13 @@ public NativeExecutionTask( this.workerClient, this.executor, taskManagerConfig.getInfoUpdateInterval(), - queryManagerConfig.getRemoteTaskMaxErrorDuration()); + queryManagerConfig.getRemoteTaskMaxErrorDuration(), + taskFinishedOrHasResult); if (!shuffleWriteInfo.isPresent()) { this.taskResultFetcher = Optional.of(new HttpNativeExecutionTaskResultFetcher( updateScheduledExecutor, - this.workerClient)); + this.workerClient, + taskFinishedOrHasResult)); } else { this.taskResultFetcher = Optional.empty(); @@ -139,6 +142,17 @@ public Optional getTaskInfo() return taskInfoFetcher.getTaskInfo(); } + public boolean isTaskDone() + { + Optional taskInfo = getTaskInfo(); + return taskInfo.isPresent() && taskInfo.get().getTaskStatus().getState().isDone(); + } + + public Object getTaskFinishedOrHasResult() + { + return taskFinishedOrHasResult; + } + /** * Blocking call to poll from result fetcher buffer. Blocks until content becomes available in the buffer, or until timeout is hit. * @@ -153,6 +167,11 @@ public Optional pollResult() return taskResultFetcher.get().pollPage(); } + public boolean hasResult() + { + return taskResultFetcher.isPresent() && taskResultFetcher.get().hasPage(); + } + /** * Blocking call to create and start native task. *

diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkNativeTaskExecutorFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkNativeTaskExecutorFactory.java index bdd9a2defcbb..50b8d46b6cd9 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkNativeTaskExecutorFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/PrestoSparkNativeTaskExecutorFactory.java @@ -27,7 +27,6 @@ import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskSource; import com.facebook.presto.execution.TaskState; -import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.RemoteTransactionHandle; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.metadata.Split; @@ -97,18 +96,18 @@ * It will send necessary metadata (e.g, plan fragment, session properties etc.) as a part of * BatchTaskUpdateRequest. It will poll the remote CPP task for status and results (pages/data if applicable) * and send these back to the Spark's RDD api - * + *

* PrestoSparkNativeTaskExecutorFactory is singleton instantiated once per executor. - * + *

* For every task it receives, it does the following * 1. Create the Native execution Process (NativeTaskExecutionFactory) ensure that is it created only once. * 2. Serialize and pass the planFragment, source-metadata (taskSources), sink-metadata (tableWriteInfo or shuffleWriteInfo) - * and submit a nativeExecutionTask. + * and submit a nativeExecutionTask. * 3. Return Iterator to sparkRDD layer. RDD execution will call the .next() methods, which will - * 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve {@link SerializedPage} back from external process. - * 3.b If no more output is available, then check if task has finished successfully or with exception - * If task finished with exception - fail the spark task (throw exception) - * IF task finished successfully - collect statistics through taskInfo object and add to accumulator + * 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve {@link SerializedPage} back from external process. + * 3.b If no more output is available, then check if task has finished successfully or with exception + * If task finished with exception - fail the spark task (throw exception) + * IF task finished successfully - collect statistics through taskInfo object and add to accumulator */ public class PrestoSparkNativeTaskExecutorFactory implements IPrestoSparkTaskExecutorFactory @@ -123,7 +122,6 @@ public class PrestoSparkNativeTaskExecutorFactory private static final TaskId DUMMY_TASK_ID = TaskId.valueOf("remotesourcetaskid.0.0.0.0"); private final SessionPropertyManager sessionPropertyManager; - private final FunctionAndTypeManager functionAndTypeManager; private final JsonCodec taskDescriptorJsonCodec; private final Codec taskSourceCodec; private final Codec taskInfoCodec; @@ -137,7 +135,6 @@ public class PrestoSparkNativeTaskExecutorFactory @Inject public PrestoSparkNativeTaskExecutorFactory( SessionPropertyManager sessionPropertyManager, - FunctionAndTypeManager functionAndTypeManager, JsonCodec 
taskDescriptorJsonCodec, Codec taskSourceCodec, Codec taskInfoCodec, @@ -149,7 +146,6 @@ public PrestoSparkNativeTaskExecutorFactory( PrestoSparkShuffleInfoTranslator shuffleInfoTranslator) { this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); - this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null"); this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "sparkTaskDescriptorJsonCodec is null"); this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null"); this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null"); @@ -255,7 +251,7 @@ public IPrestoSparkTaskExecutor doCreate( TaskInfo taskInfo = task.start(); // task creation might have failed - processTaskInfoForErrors(taskInfo); + processTaskInfoForErrorsOrCompletion(taskInfo); // 4. return output to spark RDD layer return new PrestoSparkNativeTaskOutputIterator<>(task, outputType, taskInfoCollector, taskInfoCodec, executionExceptionFactory); } @@ -286,7 +282,7 @@ private static void completeTask(CollectionAccumulator taskI PrestoSparkStatsCollectionUtils.collectMetrics(taskInfoOptional.get()); } - private static void processTaskInfoForErrors(TaskInfo taskInfo) + private static void processTaskInfoForErrorsOrCompletion(TaskInfo taskInfo) { if (!taskInfo.getTaskStatus().getState().isDone()) { log.info("processTaskInfoForErrors: task is not done yet.. %s", taskInfo); @@ -414,8 +410,9 @@ public PrestoSparkNativeTaskOutputIterator( /** * This function is called by Spark's RDD layer to check if there are output pages * There are 2 scenarios - * 1. ShuffleMap Task - Always returns false. But the internal function calls do all the work needed - * 2. Result Task - True until pages are available. False once all pages have been extracted + * 1. ShuffleMap Task - Always returns false. But the internal function calls do all the work needed + * 2. 
Result Task - True until pages are available. False once all pages have been extracted + * * @return if output is available */ @Override @@ -425,74 +422,72 @@ public boolean hasNext() return next.isPresent(); } - /** This function returns the next available page fetched from CPP process - * - * Has 3 main responsibilities - * 1) Busy-wait-for-pages-or-completion - * - * Loop until either of the 3 conditions happen - * * 1. We get a page - * * 2. Task has finished successfully - * * 3. Task has finished with error - * - * For ShuffleMap Task, as of now, the CPP process returns no pages. - * So the loop acts as a wait-for-completion loop and returns an Optional.empty() - * once the task has terminated - * - * For a Result Task, this function will return all the pages and Optional.empty() - * once all the pages have been read and the task has been terminates - * - * 2) Exception handling - * when there are no pages available, the function checks if the task has finished - * with exceptions and throws the appropriate exception back to spark's RDD processing - * layer - * - * 3) Statistics collection - * For both, when the task finished successfully or with exception, it tries to collect - * statistics if TaskInfo object is available + /** + * This function returns the next available page fetched from CPP process + *

+ * Has 3 main responsibilities + * 1) wait-for-pages-or-completion + *

+ * The thread running this method will wait until any of the following 3 conditions happens: + * * 1. We get a page + * * 2. Task has finished successfully + * * 3. Task has finished with error + *

+ * For ShuffleMap Task, as of now, the CPP process returns no pages. + * So the thread will be in WAITING state till the CPP task is done and returns an Optional.empty() + * once the task has terminated + *

+ * For a Result Task, this function will return pages retrieved from the CPP side as they become available. + * It returns an empty Optional once all the pages have been read and the task has terminated. + *

+ * 2) Exception handling + * The function also checks if the task has finished + * with exceptions and throws the appropriate exception back to Spark's RDD processing + * layer. + *

+ * 3) Statistics collection + * For both, when the task finished successfully or with exception, it tries to collect + * statistics if TaskInfo object is available * * @return Optional outputPage */ private Optional computeNext() { - // A while(true) loop is not desirable, but in this case we cannot avoid - // it because of Spark'sRDD contract, which is that this iterator either - // returns data or is complete. It CANNOT return null. - // While the remote task is still running and there is no output pages, - // we need to simulate a busy-loop to avoid returning null. - while (true) { - try { - // For ShuffleMap Task, this will always return Optional.empty() - Optional pageOptional = nativeExecutionTask.pollResult(); - - if (pageOptional.isPresent()) { - return pageOptional; + try { + Object taskFinishedOrHasResult = nativeExecutionTask.getTaskFinishedOrHasResult(); + // Blocking wait if task is still running or hasn't produced any output page + synchronized (taskFinishedOrHasResult) { + while (!nativeExecutionTask.isTaskDone() && !nativeExecutionTask.hasResult()) { + taskFinishedOrHasResult.wait(); } + } - try { - Optional taskInfo = nativeExecutionTask.getTaskInfo(); + // For ShuffleMap Task, this will always return Optional.empty() + Optional pageOptional = nativeExecutionTask.pollResult(); - // Case1: Task is still running - if (!taskInfo.isPresent() || !taskInfo.get().getTaskStatus().getState().isDone()) { - continue; - } + if (pageOptional.isPresent()) { + return pageOptional; + } - // Case 2: Task finished with errors captured inside taskInfo - processTaskInfoForErrors(taskInfo.get()); - } - catch (RuntimeException ex) { - // For a failed task, if taskInfo is present we still want to log the metrics - completeTask(taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec); - throw executionExceptionFactory.toPrestoSparkExecutionException(ex); + // Double check if current task's already done (since thread could be awoken by either having output or 
task is done above) + synchronized (taskFinishedOrHasResult) { + while (!nativeExecutionTask.isTaskDone()) { + taskFinishedOrHasResult.wait(); } - - // Case3: Task terminated with success - break; - } - catch (InterruptedException e) { - log.error(e); - throw new RuntimeException(e); } + + Optional taskInfo = nativeExecutionTask.getTaskInfo(); + + processTaskInfoForErrorsOrCompletion(taskInfo.get()); + } + catch (RuntimeException ex) { + // For a failed task, if taskInfo is present we still want to log the metrics + completeTask(taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec); + throw executionExceptionFactory.toPrestoSparkExecutionException(ex); + } + catch (InterruptedException e) { + log.error(e); + throw new RuntimeException(e); } // Reaching here marks the end of task processing diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java index 99d2c70d3ea4..e4b5bc07ea13 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java @@ -28,6 +28,8 @@ import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.TaskSource; +import com.facebook.presto.execution.TaskState; +import com.facebook.presto.execution.TaskStatus; import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.operator.PageBufferClient; import com.facebook.presto.operator.PageTransportErrorException; @@ -53,6 +55,8 @@ import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.testing.TestingSession; import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import 
com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.net.MediaType; import com.google.common.util.concurrent.AbstractFuture; @@ -334,7 +338,8 @@ public void testResultFetcher() new Duration(1, TimeUnit.SECONDS)); HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( newScheduledThreadPool(1), - workerClient); + workerClient, + new Object()); CompletableFuture future = taskResultFetcher.start(); try { future.get(); @@ -405,7 +410,8 @@ else if (requestCount == numPages) { new Duration(1, TimeUnit.SECONDS)); HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( newScheduledThreadPool(1), - workerClient); + workerClient, + new Object()); CompletableFuture future = taskResultFetcher.start(); try { future.get(); @@ -499,7 +505,8 @@ public void testResultFetcherExceedingBufferLimit() new Duration(1, TimeUnit.SECONDS)); HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( newScheduledThreadPool(10), - workerClient); + workerClient, + new Object()); CompletableFuture future = taskResultFetcher.start(); try { Optional page = Optional.empty(); @@ -604,7 +611,8 @@ public void testResultFetcherTransportErrorRecovery() new Duration(1, TimeUnit.SECONDS)); HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( newScheduledThreadPool(10), - workerClient); + workerClient, + new Object()); CompletableFuture future = taskResultFetcher.start(); try { future.get(); @@ -641,11 +649,45 @@ public void testResultFetcherTransportErrorFail() new Duration(1, TimeUnit.SECONDS)); HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( newScheduledThreadPool(1), - workerClient); + workerClient, + new Object()); CompletableFuture future = taskResultFetcher.start(); assertThrows(ExecutionException.class, future::get); } + 
@Test + public void testResultFetcherWaitOnSignal() + { + TaskId taskId = new TaskId("testid", 0, 0, 0, 0); + Object lock = new Object(); + + PrestoSparkHttpTaskClient workerClient = new PrestoSparkHttpTaskClient( + new TestingHttpClient(new TestingResponseManager(taskId.toString())), + taskId, + BASE_URI, + TASK_INFO_JSON_CODEC, + PLAN_FRAGMENT_JSON_CODEC, + TASK_UPDATE_REQUEST_JSON_CODEC, + new Duration(1, TimeUnit.SECONDS)); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = new HttpNativeExecutionTaskResultFetcher( + newScheduledThreadPool(1), + workerClient, + lock); + taskResultFetcher.start(); + try { + synchronized (lock) { + while (!taskResultFetcher.hasPage()) { + lock.wait(); + } + } + assertTrue(taskResultFetcher.hasPage()); + } + catch (InterruptedException e) { + e.printStackTrace(); + fail(); + } + } + @Test public void testInfoFetcher() { @@ -674,7 +716,8 @@ public void testInfoFetcherWithRetry() HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher( taskId, new TestingResponseManager(taskId.toString(), new FailureTaskInfoRetryResponseManager(1)), - new Duration(5, TimeUnit.SECONDS)); + new Duration(5, TimeUnit.SECONDS), + new Object()); assertFalse(taskInfoFetcher.getTaskInfo().isPresent()); taskInfoFetcher.start(); try { @@ -699,6 +742,34 @@ public void testInfoFetcherWithRetry() assertTrue(exception.getMessage().contains("TaskInfoFetcher encountered too many errors talking to native process.")); } + @Test + public void testInfoFetcherWaitOnSignal() + { + TaskId taskId = new TaskId("testid", 0, 0, 0, 0); + Object lock = new Object(); + + Duration fetchInterval = new Duration(1, TimeUnit.SECONDS); + HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString(), TaskState.FINISHED), lock); + assertFalse(taskInfoFetcher.getTaskInfo().isPresent()); + taskInfoFetcher.start(); + try { + synchronized (lock) { + while (!isTaskDone(taskInfoFetcher.getTaskInfo())) 
{ + lock.wait(); + } + } + } + catch (InterruptedException e) { + fail(); + } + assertTrue(isTaskDone(taskInfoFetcher.getTaskInfo())); + } + + private boolean isTaskDone(Optional taskInfo) + { + return taskInfo.isPresent() && taskInfo.get().getTaskStatus().getState().isDone(); + } + @Test public void testNativeExecutionTask() { @@ -786,10 +857,15 @@ private NativeExecutionProcess createNativeExecutionProcess( private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager) { - return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES)); + return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), new Object()); + } + + private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Object lock) + { + return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), lock); } - private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Duration maxErrorDuration) + private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Duration maxErrorDuration, Object lock) { PrestoSparkHttpTaskClient workerClient = new PrestoSparkHttpTaskClient( new TestingHttpClient(testingResponseManager), @@ -805,7 +881,8 @@ private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, workerClient, newSingleThreadExecutor(), new Duration(1, TimeUnit.SECONDS), - maxErrorDuration); + maxErrorDuration, + lock); } private static class TestingHttpResponseFuture @@ -1012,6 +1089,14 @@ public TestingResponseManager(String taskId) this.taskInfoResponseManager = new TestingTaskInfoResponseManager(); } + public TestingResponseManager(String taskId, TaskState taskState) + { + this.taskId = requireNonNull(taskId, "taskId is 
null"); + this.resultResponseManager = new TestingResultResponseManager(); + this.serverResponseManager = new TestingServerResponseManager(); + this.taskInfoResponseManager = new TestingTaskInfoResponseManager(taskState); + } + public TestingResponseManager(String taskId, TestingResultResponseManager resultResponseManager) { this.taskId = requireNonNull(taskId, "taskId is null"); @@ -1130,23 +1215,62 @@ protected Response createResultResponseHelper( */ public static class TestingTaskInfoResponseManager { + private final TaskState taskState; + + public TestingTaskInfoResponseManager() + { + taskState = TaskState.PLANNED; + } + + public TestingTaskInfoResponseManager(TaskState taskState) + { + this.taskState = taskState; + } + public Response createTaskInfoResponse(HttpStatus httpStatus, String taskId) throws PrestoException { + URI location = uriBuilderFrom(BASE_URI).appendPath(TASK_ROOT_PATH).build(); ListMultimap headers = ArrayListMultimap.create(); headers.put(HeaderName.of(CONTENT_TYPE), String.valueOf(MediaType.create("application", "json"))); TaskInfo taskInfo = TaskInfo.createInitialTask( TaskId.valueOf(taskId), - uriBuilderFrom(BASE_URI).appendPath(TASK_ROOT_PATH).build(), + location, new ArrayList<>(), new TaskStats(DateTime.now(), null), - "dummy-node"); + "dummy-node").withTaskStatus(createTaskStatusDone(location)); return new TestingResponse( httpStatus.code(), httpStatus.toString(), headers, new ByteArrayInputStream(taskInfoCodec.toBytes(taskInfo))); } + + private TaskStatus createTaskStatusDone(URI location) + { + return new TaskStatus( + 0L, + 0L, + 0, + taskState, + location, + ImmutableSet.of(), + ImmutableList.of(), + 0, + 0, + 0.0, + false, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0L, + 0L); + } } } @@ -1252,6 +1376,7 @@ private static class FailureTaskInfoRetryResponseManager public FailureTaskInfoRetryResponseManager(int failureCount) { + super(); this.failureCount = failureCount; }