Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ public class PrestoSparkHttpTaskClient
private static final Logger log = Logger.get(PrestoSparkHttpTaskClient.class);
private final OkHttpClient httpClient;
private final URI location;
private final URI taskUri;
private final JsonCodec<TaskInfo> taskInfoCodec;
private final JsonCodec<PlanFragment> planFragmentCodec;
private final JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec;
Expand All @@ -109,7 +108,6 @@ public class PrestoSparkHttpTaskClient

public PrestoSparkHttpTaskClient(
OkHttpClient httpClient,
TaskId taskId,
URI location,
JsonCodec<TaskInfo> taskInfoCodec,
JsonCodec<PlanFragment> planFragmentCodec,
Expand All @@ -124,7 +122,6 @@ public PrestoSparkHttpTaskClient(
this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null");
this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
this.taskUri = createTaskUri(location, taskId);
this.infoRefreshMaxWait = requireNonNull(infoRefreshMaxWait, "infoRefreshMaxWait is null");
this.executor = requireNonNull(executor, "executor is null");
this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
Expand All @@ -134,7 +131,7 @@ public PrestoSparkHttpTaskClient(
/**
* Get results from a native engine task that ends with none shuffle operator. It always fetches from a single buffer.
*/
public ListenableFuture<PagesResponse> getResults(long token, DataSize maxResponseSize)
public ListenableFuture<PagesResponse> getResults(TaskId taskId, long token, DataSize maxResponseSize)
{
RequestErrorTracker errorTracker = new RequestErrorTracker(
"NativeExecution",
Expand All @@ -145,100 +142,13 @@ public ListenableFuture<PagesResponse> getResults(long token, DataSize maxRespon
scheduledExecutorService,
"sending update request to native process");
SettableFuture<PagesResponse> result = SettableFuture.create();
scheduleGetResultsRequest(prepareGetResultsRequest(token, maxResponseSize), errorTracker, result);
scheduleGetResultsRequest(prepareGetResultsRequest(taskId, token, maxResponseSize), errorTracker, result);
return result;
}

private void scheduleGetResultsRequest(
Request request,
RequestErrorTracker errorTracker,
SettableFuture<PagesResponse> result)
{
ListenableFuture<Void> permitFuture = (ListenableFuture<Void>) errorTracker.acquireRequestPermit();
addCallback(permitFuture, new FutureCallback<Void>() {
@Override
public void onSuccess(Void ignored)
{
errorTracker.startRequest();
httpClient.newCall(request).enqueue(new Callback() {
@Override
public void onFailure(Call call, IOException e)
{
handleGetResultsFailure(e, errorTracker, request, result);
}

@Override
public void onResponse(Call call, Response response)
{
try {
BaseResponse<PagesResponse> baseResponse = new PageResponseHandler().handle(request, response);
if (baseResponse.hasValue()) {
errorTracker.requestSucceeded();
result.set(baseResponse.getValue());
}
else {
Exception exception = baseResponse.getException();
if (exception != null) {
handleGetResultsFailure(exception, errorTracker, request, result);
}
else {
handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result);
}
}
}
catch (Exception e) {
handleGetResultsFailure(e, errorTracker, request, result);
}
finally {
response.close();
}
}
});
}

@Override
public void onFailure(Throwable t)
{
result.setException(t);
}
}, executor);
}

private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker,
Request request, SettableFuture<PagesResponse> result)
{
log.info("Received failure response with exception %s", failure);
if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) {
result.setException(failure);
return;
}
try {
errorTracker.requestFailed(failure);
scheduleGetResultsRequest(request, errorTracker, result);
}
catch (Throwable t) {
result.setException(t);
}
}

private Request prepareGetResultsRequest(long token, DataSize maxResponseSize)
public void acknowledgeResultsAsync(TaskId taskId, long nextToken)
{
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
.addPathSegment("results")
.addPathSegment("0")
.addPathSegment(String.valueOf(token))
.build();

return new Request.Builder()
.url(url)
.get()
.addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
.build();
}

public void acknowledgeResultsAsync(long nextToken)
{
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
.addPathSegment("results")
.addPathSegment("0")
.addPathSegment(String.valueOf(nextToken))
Expand All @@ -259,9 +169,9 @@ public void acknowledgeResultsAsync(long nextToken)
scheduleVoidRequest(request, new BytesResponseHandler(), errorTracker, result);
}

public ListenableFuture<Void> abortResultsAsync()
public ListenableFuture<Void> abortResultsAsync(TaskId taskId)
{
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
.addPathSegment("results")
.addPathSegment("0")
.build();
Expand All @@ -280,11 +190,11 @@ public ListenableFuture<Void> abortResultsAsync()
return result;
}

public TaskInfo getTaskInfo()
public TaskInfo getTaskInfo(TaskId taskId)
{
Request request = setContentTypeHeaders(new Request.Builder())
.addHeader(PRESTO_MAX_WAIT, infoRefreshMaxWait.toString())
.url(taskUri.toString())
.url(getTaskUri(taskId).toString())
.get()
.build();
ListenableFuture<TaskInfo> future = executeWithRetries(
Expand All @@ -296,6 +206,7 @@ public TaskInfo getTaskInfo()
}

public TaskInfo updateTask(
TaskId taskId,
List<TaskSource> sources,
PlanFragment planFragment,
TableWriteInfo tableWriteInfo,
Expand All @@ -315,7 +226,7 @@ public TaskInfo updateTask(
writeInfo);
BatchTaskUpdateRequest batchTaskUpdateRequest = new BatchTaskUpdateRequest(updateRequest, shuffleWriteInfo, broadcastBasePath);

HttpUrl url = HttpUrl.get(taskUri).newBuilder()
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
.addPathSegment("batch")
.build();
byte[] requestBody = taskUpdateRequestCodec.toBytes(batchTaskUpdateRequest);
Expand All @@ -336,19 +247,101 @@ public URI getLocation()
return location;
}

public URI getTaskUri()
public URI getTaskUri(TaskId taskId)
{
return taskUri;
return HttpUrl.get(location).newBuilder()
.addPathSegment("v1")
.addPathSegment("task")
.addPathSegment(taskId.toString())
.build()
.uri();
}

private URI createTaskUri(URI baseUri, TaskId taskId)
private void scheduleGetResultsRequest(
Request request,
RequestErrorTracker errorTracker,
SettableFuture<PagesResponse> result)
{
return HttpUrl.get(baseUri).newBuilder()
.addPathSegment("v1")
.addPathSegment("task")
.addPathSegment(taskId.toString())
.build()
.uri();
ListenableFuture<Void> permitFuture = (ListenableFuture<Void>) errorTracker.acquireRequestPermit();
addCallback(permitFuture, new FutureCallback<Void>() {
@Override
public void onSuccess(Void ignored)
{
errorTracker.startRequest();
httpClient.newCall(request).enqueue(new Callback() {
@Override
public void onFailure(Call call, IOException e)
{
handleGetResultsFailure(e, errorTracker, request, result);
}

@Override
public void onResponse(Call call, Response response)
{
try {
BaseResponse<PagesResponse> baseResponse = new PageResponseHandler().handle(request, response);
if (baseResponse.hasValue()) {
errorTracker.requestSucceeded();
result.set(baseResponse.getValue());
}
else {
Exception exception = baseResponse.getException();
if (exception != null) {
handleGetResultsFailure(exception, errorTracker, request, result);
}
else {
handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result);
}
}
}
catch (Exception e) {
handleGetResultsFailure(e, errorTracker, request, result);
}
finally {
response.close();
}
}
});
}

@Override
public void onFailure(Throwable t)
{
result.setException(t);
}
}, executor);
}

private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker,
Request request, SettableFuture<PagesResponse> result)
{
log.info("Received failure response with exception %s", failure);
if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) {
result.setException(failure);
return;
}
try {
errorTracker.requestFailed(failure);
scheduleGetResultsRequest(request, errorTracker, result);
}
catch (Throwable t) {
result.setException(t);
}
}

private Request prepareGetResultsRequest(TaskId taskId, long token, DataSize maxResponseSize)
{
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
.addPathSegment("results")
.addPathSegment("0")
.addPathSegment(String.valueOf(token))
.build();

return new Request.Builder()
.url(url)
.get()
.addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
.build();
}

private <T> ListenableFuture<T> executeWithRetries(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.airlift.units.Duration;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -39,6 +40,7 @@ public class HttpNativeExecutionTaskInfoFetcher
{
private static final Logger log = Logger.get(HttpNativeExecutionTaskInfoFetcher.class);

private final TaskId taskId;
private final PrestoSparkHttpTaskClient workerClient;
private final ScheduledExecutorService updateScheduledExecutor;
private final AtomicReference<TaskInfo> taskInfo = new AtomicReference<>();
Expand All @@ -50,11 +52,13 @@ public class HttpNativeExecutionTaskInfoFetcher
private ScheduledFuture<?> scheduledFuture;

public HttpNativeExecutionTaskInfoFetcher(
TaskId taskId,
ScheduledExecutorService updateScheduledExecutor,
PrestoSparkHttpTaskClient workerClient,
Duration infoFetchInterval,
Object taskFinished)
{
this.taskId = requireNonNull(taskId, "taskId is null");
this.workerClient = requireNonNull(workerClient, "workerClient is null");
this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null");
this.infoFetchInterval = requireNonNull(infoFetchInterval, "infoFetchInterval is null");
Expand All @@ -78,7 +82,7 @@ public void stop()
void doGetTaskInfo()
{
try {
TaskInfo result = workerClient.getTaskInfo();
TaskInfo result = workerClient.getTaskInfo(taskId);
onSuccess(result);
}
catch (Throwable t) {
Expand Down
Loading
Loading