diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java index 0b395e55a0a75..161e51293c48f 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java @@ -98,7 +98,6 @@ public class PrestoSparkHttpTaskClient private static final Logger log = Logger.get(PrestoSparkHttpTaskClient.class); private final OkHttpClient httpClient; private final URI location; - private final URI taskUri; private final JsonCodec taskInfoCodec; private final JsonCodec planFragmentCodec; private final JsonCodec taskUpdateRequestCodec; @@ -109,7 +108,6 @@ public class PrestoSparkHttpTaskClient public PrestoSparkHttpTaskClient( OkHttpClient httpClient, - TaskId taskId, URI location, JsonCodec taskInfoCodec, JsonCodec planFragmentCodec, @@ -124,7 +122,6 @@ public PrestoSparkHttpTaskClient( this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null"); this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null"); this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null"); - this.taskUri = createTaskUri(location, taskId); this.infoRefreshMaxWait = requireNonNull(infoRefreshMaxWait, "infoRefreshMaxWait is null"); this.executor = requireNonNull(executor, "executor is null"); this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null"); @@ -134,7 +131,7 @@ public PrestoSparkHttpTaskClient( /** * Get results from a native engine task that ends with none shuffle operator. It always fetches from a single buffer. */ - public ListenableFuture getResults(long token, DataSize maxResponseSize) + public ListenableFuture getResults(TaskId taskId, long token, DataSize maxResponseSize) { RequestErrorTracker errorTracker = new RequestErrorTracker( "NativeExecution", @@ -145,100 +142,13 @@ public ListenableFuture getResults(long token, DataSize maxRespon scheduledExecutorService, "sending update request to native process"); SettableFuture result = SettableFuture.create(); - scheduleGetResultsRequest(prepareGetResultsRequest(token, maxResponseSize), errorTracker, result); + scheduleGetResultsRequest(prepareGetResultsRequest(taskId, token, maxResponseSize), errorTracker, result); return result; } - private void scheduleGetResultsRequest( - Request request, - RequestErrorTracker errorTracker, - SettableFuture result) - { - ListenableFuture permitFuture = (ListenableFuture) errorTracker.acquireRequestPermit(); - addCallback(permitFuture, new FutureCallback() { - @Override - public void onSuccess(Void ignored) - { - errorTracker.startRequest(); - httpClient.newCall(request).enqueue(new Callback() { - @Override - public void onFailure(Call call, IOException e) - { - handleGetResultsFailure(e, errorTracker, request, result); - } - - @Override - public void onResponse(Call call, Response response) - { - try { - BaseResponse baseResponse = new PageResponseHandler().handle(request, response); - if (baseResponse.hasValue()) { - errorTracker.requestSucceeded(); - result.set(baseResponse.getValue()); - } - else { - Exception exception = baseResponse.getException(); - if (exception != null) { - handleGetResultsFailure(exception, errorTracker, request, result); - } - else { - handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result); - } - } - } - catch (Exception e) { - handleGetResultsFailure(e, errorTracker, request, result); - } - finally { - response.close(); - } - } - }); - } - - @Override - public void onFailure(Throwable t) - { - result.setException(t); - } - }, executor); - } - - private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker, - Request request, SettableFuture result) - { - log.info("Received failure response with exception %s", failure); - if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) { - result.setException(failure); - return; - } - try { - errorTracker.requestFailed(failure); - scheduleGetResultsRequest(request, errorTracker, result); - } - catch (Throwable t) { - result.setException(t); - } - } - - private Request prepareGetResultsRequest(long token, DataSize maxResponseSize) + public void acknowledgeResultsAsync(TaskId taskId, long nextToken) { - HttpUrl url = HttpUrl.get(taskUri).newBuilder() - .addPathSegment("results") - .addPathSegment("0") - .addPathSegment(String.valueOf(token)) - .build(); - - return new Request.Builder() - .url(url) - .get() - .addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()) - .build(); - } - - public void acknowledgeResultsAsync(long nextToken) - { - HttpUrl url = HttpUrl.get(taskUri).newBuilder() + HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder() .addPathSegment("results") .addPathSegment("0") .addPathSegment(String.valueOf(nextToken)) @@ -259,9 +169,9 @@ public void acknowledgeResultsAsync(long nextToken) scheduleVoidRequest(request, new BytesResponseHandler(), errorTracker, result); } - public ListenableFuture abortResultsAsync() + public ListenableFuture abortResultsAsync(TaskId taskId) { - HttpUrl url = HttpUrl.get(taskUri).newBuilder() + HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder() .addPathSegment("results") .addPathSegment("0") .build(); @@ -280,11 +190,11 @@ public ListenableFuture abortResultsAsync() return result; } - public TaskInfo getTaskInfo() + public TaskInfo getTaskInfo(TaskId taskId) { Request request = setContentTypeHeaders(new Request.Builder()) .addHeader(PRESTO_MAX_WAIT, infoRefreshMaxWait.toString()) - .url(taskUri.toString()) + .url(getTaskUri(taskId).toString()) .get() .build(); ListenableFuture future = executeWithRetries( @@ -296,6 +206,7 @@ public TaskInfo getTaskInfo() } public TaskInfo updateTask( + TaskId taskId, List sources, PlanFragment planFragment, TableWriteInfo tableWriteInfo, @@ -315,7 +226,7 @@ public TaskInfo updateTask( writeInfo); BatchTaskUpdateRequest batchTaskUpdateRequest = new BatchTaskUpdateRequest(updateRequest, shuffleWriteInfo, broadcastBasePath); - HttpUrl url = HttpUrl.get(taskUri).newBuilder() + HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder() .addPathSegment("batch") .build(); byte[] requestBody = taskUpdateRequestCodec.toBytes(batchTaskUpdateRequest); @@ -336,19 +247,101 @@ public URI getLocation() return location; } - public URI getTaskUri() + public URI getTaskUri(TaskId taskId) { - return taskUri; + return HttpUrl.get(location).newBuilder() + .addPathSegment("v1") + .addPathSegment("task") + .addPathSegment(taskId.toString()) + .build() + .uri(); } - private URI createTaskUri(URI baseUri, TaskId taskId) + private void scheduleGetResultsRequest( + Request request, + RequestErrorTracker errorTracker, + SettableFuture result) { - return HttpUrl.get(baseUri).newBuilder() - .addPathSegment("v1") - .addPathSegment("task") - .addPathSegment(taskId.toString()) - .build() - .uri(); + ListenableFuture permitFuture = (ListenableFuture) errorTracker.acquireRequestPermit(); + addCallback(permitFuture, new FutureCallback() { + @Override + public void onSuccess(Void ignored) + { + errorTracker.startRequest(); + httpClient.newCall(request).enqueue(new Callback() { + @Override + public void onFailure(Call call, IOException e) + { + handleGetResultsFailure(e, errorTracker, request, result); + } + + @Override + public void onResponse(Call call, Response response) + { + try { + BaseResponse baseResponse = new PageResponseHandler().handle(request, response); + if (baseResponse.hasValue()) { + errorTracker.requestSucceeded(); + result.set(baseResponse.getValue()); + } + else { + Exception exception = baseResponse.getException(); + if (exception != null) { + handleGetResultsFailure(exception, errorTracker, request, result); + } + else { + handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result); + } + } + } + catch (Exception e) { + handleGetResultsFailure(e, errorTracker, request, result); + } + finally { + response.close(); + } + } + }); + } + + @Override + public void onFailure(Throwable t) + { + result.setException(t); + } + }, executor); + } + + private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker, + Request request, SettableFuture result) + { + log.info("Received failure response with exception %s", failure); + if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) { + result.setException(failure); + return; + } + try { + errorTracker.requestFailed(failure); + scheduleGetResultsRequest(request, errorTracker, result); + } + catch (Throwable t) { + result.setException(t); + } + } + + private Request prepareGetResultsRequest(TaskId taskId, long token, DataSize maxResponseSize) + { + HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder() + .addPathSegment("results") + .addPathSegment("0") + .addPathSegment(String.valueOf(token)) + .build(); + + return new Request.Builder() + .url(url) + .get() + .addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()) + .build(); } private ListenableFuture executeWithRetries( diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java index 66ba162b54b25..037486f83f913 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.units.Duration; +import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient; import com.google.common.annotations.VisibleForTesting; @@ -39,6 +40,7 @@ public class HttpNativeExecutionTaskInfoFetcher { private static final Logger log = Logger.get(HttpNativeExecutionTaskInfoFetcher.class); + private final TaskId taskId; private final PrestoSparkHttpTaskClient workerClient; private final ScheduledExecutorService updateScheduledExecutor; private final AtomicReference taskInfo = new AtomicReference<>(); @@ -50,11 +52,13 @@ public class HttpNativeExecutionTaskInfoFetcher private ScheduledFuture scheduledFuture; public HttpNativeExecutionTaskInfoFetcher( + TaskId taskId, ScheduledExecutorService updateScheduledExecutor, PrestoSparkHttpTaskClient workerClient, Duration infoFetchInterval, Object taskFinished) { + this.taskId = requireNonNull(taskId, "taskId is null"); this.workerClient = requireNonNull(workerClient, "workerClient is null"); this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null"); this.infoFetchInterval = requireNonNull(infoFetchInterval, "infoFetchInterval is null"); @@ -78,7 +82,7 @@ public void stop() void doGetTaskInfo() { try { - TaskInfo result = workerClient.getTaskInfo(); + TaskInfo result = workerClient.getTaskInfo(taskId); onSuccess(result); } catch (Throwable t) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java index 7ecfa9b3d5d8c..881ae57098acc 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskResultFetcher.java @@ -16,6 +16,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.units.DataSize; import com.facebook.airlift.units.Duration; +import com.facebook.presto.execution.TaskId; import com.facebook.presto.operator.PageBufferClient; import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient; import com.facebook.presto.spi.HostAddress; @@ -58,6 +59,7 @@ public class HttpNativeExecutionTaskResultFetcher private static final DataSize MAX_RESPONSE_SIZE = new DataSize(32, DataSize.Unit.MEGABYTE); private static final DataSize MAX_BUFFER_SIZE = new DataSize(128, DataSize.Unit.MEGABYTE); + private final TaskId taskId; private final ScheduledExecutorService scheduler; private final PrestoSparkHttpTaskClient workerClient; private final LinkedBlockingDeque pageBuffer = new LinkedBlockingDeque<>(); @@ -72,10 +74,12 @@ public class HttpNativeExecutionTaskResultFetcher private long token; public HttpNativeExecutionTaskResultFetcher( + TaskId taskId, ScheduledExecutorService scheduler, PrestoSparkHttpTaskClient workerClient, Object taskHasResult) { + this.taskId = taskId; this.scheduler = requireNonNull(scheduler, "scheduler is null"); this.workerClient = requireNonNull(workerClient, "workerClient is null"); this.bufferMemoryBytes = new AtomicLong(); @@ -146,7 +150,8 @@ private void doGetResults() } try { - PageBufferClient.PagesResponse pagesResponse = getFutureValue(workerClient.getResults(token, MAX_RESPONSE_SIZE)); + PageBufferClient.PagesResponse pagesResponse = getFutureValue( + workerClient.getResults(taskId, token, MAX_RESPONSE_SIZE)); onSuccess(pagesResponse); } catch (Throwable t) { @@ -169,18 +174,19 @@ private void onSuccess(PageBufferClient.PagesResponse pagesResponse) bytes += page.getSizeInBytes(); positionCount += page.getPositionCount(); } - log.info("Received %s rows in %s pages from %s", positionCount, pages.size(), workerClient.getTaskUri()); + log.info("Received %s rows in %s pages from %s", positionCount, pages.size(), + workerClient.getTaskUri(taskId)); pageBuffer.addAll(pages); bufferMemoryBytes.addAndGet(bytes); long nextToken = pagesResponse.getNextToken(); if (pages.size() > 0) { - workerClient.acknowledgeResultsAsync(nextToken); + workerClient.acknowledgeResultsAsync(taskId, nextToken); } token = nextToken; if (pagesResponse.isClientComplete()) { completed = true; - workerClient.abortResultsAsync(); + workerClient.abortResultsAsync(taskId); if (scheduledFuture != null) { scheduledFuture.cancel(false); } @@ -194,7 +200,7 @@ private void onSuccess(PageBufferClient.PagesResponse pagesResponse) private void onFailure(Throwable t) { - workerClient.abortResultsAsync(); + workerClient.abortResultsAsync(taskId); stop(false); lastException.set(t); synchronized (taskHasResult) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java index fa08889dd8241..3fc810076c1ab 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTask.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.TaskSource; @@ -35,6 +36,7 @@ import static com.facebook.presto.execution.TaskState.CANCELED; import static com.facebook.presto.execution.TaskState.FAILED; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** @@ -53,6 +55,7 @@ public class NativeExecutionTask { private static final Logger log = Logger.get(NativeExecutionTask.class); + private final TaskId taskId; private final Session session; private final PlanFragment planFragment; private final OutputBuffers outputBuffers; @@ -67,6 +70,7 @@ public class NativeExecutionTask private final Object taskFinishedOrHasResult = new Object(); public NativeExecutionTask( + TaskId taskId, Session session, PrestoSparkHttpTaskClient workerClient, PlanFragment planFragment, @@ -77,6 +81,7 @@ public NativeExecutionTask( ScheduledExecutorService scheduledExecutorService, TaskManagerConfig taskManagerConfig) { + this.taskId = requireNonNull(taskId, "taskId is null"); this.session = requireNonNull(session, "session is null"); this.planFragment = requireNonNull(planFragment, "planFragment is null"); this.tableWriteInfo = requireNonNull(tableWriteInfo, "tableWriteInfo is null"); @@ -88,12 +93,14 @@ public NativeExecutionTask( requireNonNull(taskManagerConfig, "taskManagerConfig is null"); requireNonNull(scheduledExecutorService, "scheduledExecutorService is null"); this.taskInfoFetcher = new HttpNativeExecutionTaskInfoFetcher( + taskId, scheduledExecutorService, this.workerClient, taskManagerConfig.getInfoUpdateInterval(), taskFinishedOrHasResult); if (!shuffleWriteInfo.isPresent()) { this.taskResultFetcher = Optional.of(new HttpNativeExecutionTaskResultFetcher( + this.taskId, scheduledExecutorService, this.workerClient, taskFinishedOrHasResult)); @@ -156,7 +163,7 @@ public TaskInfo start() // We do not start taskInfo fetcher for failed tasks if (!ImmutableList.of(CANCELED, FAILED, ABORTED).contains(taskInfo.getTaskStatus().getState())) { - log.info("Starting TaskInfoFetcher and TaskResultFetcher."); + log.info(format("Starting TaskInfoFetcher and TaskResultFetcher for %s.", taskId.toString())); taskResultFetcher.ifPresent(fetcher -> fetcher.start()); taskInfoFetcher.start(); } @@ -171,12 +178,13 @@ public void stop(boolean success) { taskInfoFetcher.stop(); taskResultFetcher.ifPresent(fetcher -> fetcher.stop(success)); - workerClient.abortResultsAsync(); + workerClient.abortResultsAsync(taskId); } private TaskInfo sendUpdateRequest() { return workerClient.updateTask( + taskId, sources, planFragment, tableWriteInfo, diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java index 33f1fcc53546b..d209096c07aaa 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/NativeExecutionTaskFactory.java @@ -95,7 +95,6 @@ public NativeExecutionTask createNativeExecutionTask( : Optional.empty(); PrestoSparkHttpTaskClient workerClient = new PrestoSparkHttpTaskClient( httpClient, - taskId, location, taskInfoCodec, planFragmentCodec, @@ -105,6 +104,7 @@ public NativeExecutionTask createNativeExecutionTask( scheduledExecutorService, queryManagerConfig.getRemoteTaskMaxErrorDuration()); return new NativeExecutionTask( + taskId, session, workerClient, fragment, diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java index 30a3359c4c3de..09b56a833e87b 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/execution/task/PrestoSparkNativeTaskExecutorFactory.java @@ -313,7 +313,7 @@ public IPrestoSparkTaskExecutor doCreate( Duration terminateWithCoreTimeout = getNativeTerminateWithCoreTimeout(session); try { // 3. Submit the task to cpp process for execution - log.info("Submitting native execution task "); + log.info(format("Submitting native execution task. taskId %s", taskId.toString())); NativeExecutionTask task = nativeExecutionTaskFactory.createNativeExecutionTask( session, nativeExecutionProcess.getLocation(), diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java index f2dfe3d1b40ca..fcbe8caeb4855 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/http/TestPrestoSparkHttpClient.java @@ -150,6 +150,7 @@ public void testResultGet() PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); ListenableFuture future = workerClient.getResults( + taskId, 0, new DataSize(32, MEGABYTE)); try { @@ -170,19 +171,28 @@ public void testResultAcknowledge() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); - workerClient.acknowledgeResultsAsync(1); + workerClient.acknowledgeResultsAsync(taskId, 1); } private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId) { - return createWorkerClient(taskId, new TestingOkHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString()))); + return new PrestoSparkHttpTaskClient( + new TestingOkHttpClient(scheduledExecutorService, + new TestingResponseManager(taskId.toString())), + BASE_URI, + TASK_INFO_JSON_CODEC, + PLAN_FRAGMENT_JSON_CODEC, + TASK_UPDATE_REQUEST_JSON_CODEC, + new Duration(1, TimeUnit.SECONDS), + scheduledExecutorService, + scheduledExecutorService, + new Duration(1, TimeUnit.SECONDS)); } - private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId, TestingOkHttpClient httpClient) + private PrestoSparkHttpTaskClient createWorkerClient(TestingOkHttpClient httpClient) { return new PrestoSparkHttpTaskClient( httpClient, - taskId, BASE_URI, TASK_INFO_JSON_CODEC, PLAN_FRAGMENT_JSON_CODEC, @@ -193,14 +203,19 @@ private PrestoSparkHttpTaskClient createWorkerClient(TaskId taskId, TestingOkHtt new Duration(1, TimeUnit.SECONDS)); } - HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient) + HttpNativeExecutionTaskResultFetcher createResultFetcher( + TaskId taskId, + PrestoSparkHttpTaskClient workerClient) { - return createResultFetcher(workerClient, new Object()); + return createResultFetcher(taskId, workerClient, new Object()); } - HttpNativeExecutionTaskResultFetcher createResultFetcher(PrestoSparkHttpTaskClient workerClient, Object lock) + HttpNativeExecutionTaskResultFetcher createResultFetcher( + TaskId taskId, + PrestoSparkHttpTaskClient workerClient, Object lock) { return new HttpNativeExecutionTaskResultFetcher( + taskId, scheduledExecutorService, workerClient, lock); @@ -212,7 +227,7 @@ public void testResultAbort() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); - ListenableFuture future = workerClient.abortResultsAsync(); + ListenableFuture future = workerClient.abortResultsAsync(taskId); try { future.get(); } @@ -229,7 +244,7 @@ public void testGetTaskInfo() PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); try { - TaskInfo taskInfo = workerClient.getTaskInfo(); + TaskInfo taskInfo = workerClient.getTaskInfo(taskId); assertEquals(taskInfo.getTaskId().toString(), taskId.toString()); } catch (Exception e) { @@ -249,6 +264,7 @@ public void testUpdateTask() try { TaskInfo taskInfo = workerClient.updateTask( + taskId, sources, createPlanFragment(), new TableWriteInfo(Optional.empty(), Optional.empty()), @@ -269,9 +285,11 @@ public void testUpdateTaskUnexpectedResponse() { TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, - new TestingOkHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager()))); + new TestingOkHttpClient(scheduledExecutorService, + new TestingResponseManager(taskId.toString(), + new UnexpectedResponseTaskInfoRetryResponseManager()))); assertThatThrownBy(() -> workerClient.updateTask( + taskId, new ArrayList<>(), createPlanFragment(), new TableWriteInfo(Optional.empty(), Optional.empty()), @@ -288,9 +306,11 @@ public void testUpdateTaskWithRetries() { TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, - new TestingOkHttpClient(scheduledExecutorService, new TestingResponseManager(taskId.toString(), new FailureRetryTaskInfoResponseManager(2)))); + new TestingOkHttpClient(scheduledExecutorService, + new TestingResponseManager(taskId.toString(), + new FailureRetryTaskInfoResponseManager(2)))); workerClient.updateTask( + taskId, new ArrayList<>(), createPlanFragment(), new TableWriteInfo(Optional.empty(), Optional.empty()), @@ -362,7 +382,8 @@ public void testResultFetcher() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient); taskResultFetcher.start(); try { List pages = new ArrayList<>(); @@ -401,7 +422,6 @@ public void testResultFetcherMultipleNonEmptyResults() int serializedPageSize = (int) new DataSize(1, MEGABYTE).toBytes(); int numPages = 10; PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, new TestingOkHttpClient( scheduledExecutorService, new TestingResponseManager(taskId.toString(), new TestingResponseManager.TestingResultResponseManager() @@ -409,7 +429,8 @@ public void testResultFetcherMultipleNonEmptyResults() private int requestCount; @Override - public Response createResultResponse(String taskId, Request request) + public Response createResultResponse(String taskId, + Request request) throws PageTransportErrorException { requestCount++; @@ -439,7 +460,8 @@ else if (requestCount == numPages) { } } }))); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient); taskResultFetcher.start(); try { List pages = fetchResults(taskResultFetcher, numPages); @@ -517,13 +539,13 @@ public void testResultFetcherExceedingBufferLimit() new BreakingLimitResponseManager(serializedPageSize, numPages); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, new TestingOkHttpClient( scheduledExecutorService, new TestingResponseManager( taskId.toString(), breakingLimitResponseManager))); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient); taskResultFetcher.start(); try { Optional page = Optional.empty(); @@ -635,11 +657,11 @@ public void testResultFetcherTransportErrorRecovery() new TimeoutResponseManager(serializedPageSize, numPages, numTransportErrors); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, new TestingOkHttpClient( scheduledExecutorService, new TestingResponseManager(taskId.toString(), timeoutResponseManager))); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient); taskResultFetcher.start(); try { List pages = fetchResults(taskResultFetcher, numPages); @@ -662,11 +684,12 @@ public void testResultFetcherTransportErrorFail() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, new TestingOkHttpClient( scheduledExecutorService, - new TestingResponseManager(taskId.toString(), new TimeoutResponseManager(0, 10, 10)))); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient); + new TestingResponseManager(taskId.toString(), + new TimeoutResponseManager(0, 10, 10)))); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient); taskResultFetcher.start(); try { for (int i = 0; i < 1_000; ++i) { @@ -685,12 +708,13 @@ public void testResultFetcherPrestoException() { TaskId taskId = new TaskId("testid", 0, 0, 0, 0); PrestoSparkHttpTaskClient workerClient = createWorkerClient( - taskId, new TestingOkHttpClient( scheduledExecutorService, - new TestingResponseManager(taskId.toString(), new PrestoExceptionResponseManager()))); + new TestingResponseManager(taskId.toString(), + new PrestoExceptionResponseManager()))); Object monitor = new Object(); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, monitor); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient, monitor); taskResultFetcher.start(); synchronized (monitor) { try { @@ -713,7 +737,8 @@ public void testResultFetcherWaitOnSignal() Object lock = new Object(); PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId); - HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(workerClient, lock); + HttpNativeExecutionTaskResultFetcher taskResultFetcher = createResultFetcher(taskId, + workerClient, lock); taskResultFetcher.start(); try { synchronized (lock) { @@ -735,7 +760,8 @@ public void testInfoFetcher() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); Duration fetchInterval = new Duration(1, TimeUnit.SECONDS); - HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString())); + HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, + new TestingResponseManager(taskId.toString())); assertFalse(taskInfoFetcher.getTaskInfo().isPresent()); taskInfoFetcher.start(); try { @@ -792,12 +818,14 @@ public void testInfoFetcherUnexpectedResponse() Object monitor = new Object(); HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher( taskId, - new TestingResponseManager(taskId.toString(), new UnexpectedResponseTaskInfoRetryResponseManager()), + new TestingResponseManager(taskId.toString(), + new UnexpectedResponseTaskInfoRetryResponseManager()), new Duration(5, TimeUnit.SECONDS), monitor); taskInfoFetcher.start(); synchronized (monitor) { - while (taskInfoFetcher.getLastException().get() == null && !taskInfoFetcher.getTaskInfo().isPresent()) { + while (taskInfoFetcher.getLastException().get() == null + && !taskInfoFetcher.getTaskInfo().isPresent()) { monitor.wait(); } } @@ -812,7 +840,8 @@ public void testInfoFetcherWaitOnSignal() TaskId taskId = new TaskId("testid", 0, 0, 0, 0); Object lock = new Object(); - HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, new TestingResponseManager(taskId.toString(), TaskState.FINISHED), lock); + HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = createTaskInfoFetcher(taskId, + new TestingResponseManager(taskId.toString(), TaskState.FINISHED), lock); assertFalse(taskInfoFetcher.getTaskInfo().isPresent()); taskInfoFetcher.start(); try { @@ -850,7 +879,8 @@ public void testNativeExecutionTask() NativeExecutionTaskFactory taskFactory = new NativeExecutionTaskFactory( new TestingOkHttpClient( scheduledExecutorService, - new TestingResponseManager(taskId.toString(), new TimeoutResponseManager(0, 10, 0))), + new TestingResponseManager(taskId.toString(), + new TimeoutResponseManager(0, 10, 0))), scheduledExecutorService, scheduledExecutorService, TASK_INFO_JSON_CODEC, @@ -912,20 +942,25 @@ private NativeExecutionProcess createNativeExecutionProcess( Optional.empty()); } - private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager) + private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, + TestingResponseManager testingResponseManager) { return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), new Object()); } - private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Object lock) + private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, + TestingResponseManager testingResponseManager, Object lock) { return createTaskInfoFetcher(taskId, testingResponseManager, new Duration(1, TimeUnit.MINUTES), lock); } - private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, TestingResponseManager testingResponseManager, Duration maxErrorDuration, Object lock) + private HttpNativeExecutionTaskInfoFetcher createTaskInfoFetcher(TaskId taskId, + TestingResponseManager testingResponseManager, Duration maxErrorDuration, Object lock) { - PrestoSparkHttpTaskClient workerClient = createWorkerClient(taskId, new TestingOkHttpClient(scheduledExecutorService, testingResponseManager)); + PrestoSparkHttpTaskClient workerClient = createWorkerClient( + new TestingOkHttpClient(scheduledExecutorService, testingResponseManager)); return new HttpNativeExecutionTaskInfoFetcher( + taskId, scheduledExecutorService, workerClient, new Duration(1, TimeUnit.SECONDS), @@ -939,7 +974,8 @@ public static class TestingOkHttpClient private final ScheduledExecutorService executor; private final TestingResponseManager responseManager; - public TestingOkHttpClient(ScheduledExecutorService executor, TestingResponseManager responseManager) + public TestingOkHttpClient(ScheduledExecutorService executor, + TestingResponseManager responseManager) { this.executor = executor; this.responseManager = responseManager; @@ -961,7 +997,8 @@ public static class TestingCall private final TestingResponseManager responseManager; private boolean executed; - public TestingCall(Request request, ScheduledExecutorService executor, TestingResponseManager responseManager) + public TestingCall(Request request, ScheduledExecutorService executor, + TestingResponseManager responseManager) { this.request = request; this.executor = executor; @@ -979,7 +1016,8 @@ public Request request() @Override public Response execute() { - throw new UnsupportedOperationException("TestingCall should use enqueue() method for proper testing"); + throw new UnsupportedOperationException( + "TestingCall should use enqueue() method for proper testing"); } public Response executeAndGetTestingResponse() diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java index eb7df66fb2f4c..dd32d50da0f05 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/execution/nativeprocess/TestHttpNativeExecutionTaskInfoFetcher.java @@ -36,10 +36,14 @@ public class TestHttpNativeExecutionTaskInfoFetcher { private static final URI BASE_URI = URI.create("http://localhost"); private static final TaskId TEST_TASK_ID = TaskId.valueOf("test.0.0.0.0"); - private static final JsonCodec TASK_INFO_JSON_CODEC = JsonCodec.jsonCodec(TaskInfo.class); - private static final JsonCodec PLAN_FRAGMENT_JSON_CODEC = JsonCodec.jsonCodec(PlanFragment.class); - private static final JsonCodec TASK_UPDATE_REQUEST_JSON_CODEC = JsonCodec.jsonCodec(BatchTaskUpdateRequest.class); - private static final ScheduledExecutorService updateScheduledExecutor = newScheduledThreadPool(4); + private static final JsonCodec TASK_INFO_JSON_CODEC = JsonCodec.jsonCodec( + TaskInfo.class); + private static final JsonCodec PLAN_FRAGMENT_JSON_CODEC = JsonCodec.jsonCodec( + PlanFragment.class); + private static final JsonCodec TASK_UPDATE_REQUEST_JSON_CODEC = JsonCodec.jsonCodec( + BatchTaskUpdateRequest.class); + private static final ScheduledExecutorService updateScheduledExecutor = newScheduledThreadPool( + 4); @Test public void testNativeExecutionTaskFailsWhenProcessCrashes() @@ -50,7 +54,6 @@ public void testNativeExecutionTaskFailsWhenProcessCrashes() new TestPrestoSparkHttpClient.TestingResponseManager( TEST_TASK_ID.toString(), new TestPrestoSparkHttpClient.TestingResponseManager.CrashingTaskInfoResponseManager())), - TEST_TASK_ID, BASE_URI, TASK_INFO_JSON_CODEC, PLAN_FRAGMENT_JSON_CODEC, @@ -64,6 +67,7 @@ public void testNativeExecutionTaskFailsWhenProcessCrashes() Object taskFinishedOrLostSignal = new Object(); HttpNativeExecutionTaskInfoFetcher taskInfoFetcher = new HttpNativeExecutionTaskInfoFetcher( + TEST_TASK_ID, updateScheduledExecutor, workerClient, new Duration(1, TimeUnit.SECONDS),