Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class HttpNativeExecutionTaskInfoFetcher
private final ScheduledExecutorService errorRetryScheduledExecutor;
private final AtomicReference<RuntimeException> lastException = new AtomicReference<>();
private final Duration maxErrorDuration;
private final Object taskFinished;

@GuardedBy("this")
private ScheduledFuture<?> scheduledFuture;
Expand All @@ -68,7 +69,8 @@ public HttpNativeExecutionTaskInfoFetcher(
PrestoSparkHttpTaskClient workerClient,
Executor executor,
Duration infoFetchInterval,
Duration maxErrorDuration)
Duration maxErrorDuration,
Object taskFinished)
{
this.workerClient = requireNonNull(workerClient, "workerClient is null");
this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null");
Expand All @@ -84,6 +86,7 @@ public HttpNativeExecutionTaskInfoFetcher(
maxErrorDuration,
errorRetryScheduledExecutor,
"getting taskInfo from native process");
this.taskFinished = requireNonNull(taskFinished, "taskFinished is null");
}

public void start()
Expand All @@ -101,6 +104,11 @@ public void onSuccess(BaseResponse<TaskInfo> result)
{
log.debug("TaskInfoCallback success %s", result.getValue().getTaskId());
taskInfo.set(result.getValue());
if (result.getValue().getTaskStatus().getState().isDone()) {
synchronized (taskFinished) {
taskFinished.notifyAll();
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,20 @@ public class HttpNativeExecutionTaskResultFetcher
private final PrestoSparkHttpTaskClient workerClient;
private final LinkedBlockingDeque<SerializedPage> pageBuffer = new LinkedBlockingDeque<>();
private final AtomicLong bufferMemoryBytes;
private final Object taskHasResult;

private ScheduledFuture<?> schedulerFuture;
private boolean started;

public HttpNativeExecutionTaskResultFetcher(
ScheduledExecutorService scheduler,
PrestoSparkHttpTaskClient workerClient)
PrestoSparkHttpTaskClient workerClient,
Object taskHasResult)
{
this.scheduler = requireNonNull(scheduler, "scheduler is null");
this.workerClient = requireNonNull(workerClient, "workerClient is null");
this.bufferMemoryBytes = new AtomicLong();
this.taskHasResult = requireNonNull(taskHasResult, "taskHasResult is null");
}

public CompletableFuture<Void> start()
Expand All @@ -85,7 +88,8 @@ public CompletableFuture<Void> start()
workerClient,
future,
pageBuffer,
bufferMemoryBytes),
bufferMemoryBytes,
taskHasResult),
0,
(long) FETCH_INTERVAL.getValue(),
FETCH_INTERVAL.getUnit());
Expand Down Expand Up @@ -125,6 +129,11 @@ public Optional<SerializedPage> pollPage()
return Optional.empty();
}

/**
 * Non-blocking check for buffered output.
 *
 * @return true when at least one {@link SerializedPage} is currently buffered
 */
public boolean hasPage()
{
    // LinkedBlockingDeque never holds nulls, so a null peek means the buffer is empty.
    return pageBuffer.peek() != null;
}

private static class HttpNativeExecutionTaskResultFetcherRunner
implements Runnable
{
Expand All @@ -135,6 +144,7 @@ private static class HttpNativeExecutionTaskResultFetcherRunner
private final LinkedBlockingDeque<SerializedPage> pageBuffer;
private final AtomicLong bufferMemoryBytes;
private final CompletableFuture<Void> future;
private final Object taskFinishedOrHasResult;

private int transportErrorRetries;
private long token;
Expand All @@ -143,14 +153,16 @@ public HttpNativeExecutionTaskResultFetcherRunner(
PrestoSparkHttpTaskClient client,
CompletableFuture<Void> future,
LinkedBlockingDeque<SerializedPage> pageBuffer,
AtomicLong bufferMemoryBytes)
AtomicLong bufferMemoryBytes,
Object taskFinishedOrHasResult)
{
this.client = requireNonNull(client, "client is null");
this.future = requireNonNull(future, "future is null");
this.pageBuffer = requireNonNull(pageBuffer, "pageBuffer is null");
this.bufferMemoryBytes = requireNonNull(
bufferMemoryBytes,
"bufferMemoryBytes is null");
this.taskFinishedOrHasResult = requireNonNull(taskFinishedOrHasResult, "taskFinishedOrHasResult is null");
}

@Override
Expand Down Expand Up @@ -184,6 +196,11 @@ public void run()
client.abortResults();
future.complete(null);
}
if (!pages.isEmpty()) {
synchronized (taskFinishedOrHasResult) {
taskFinishedOrHasResult.notifyAll();
}
}
}
catch (InterruptedException e) {
if (!future.isDone()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class NativeExecutionTask
private final HttpNativeExecutionTaskInfoFetcher taskInfoFetcher;
// Results will be fetched only if not written to shuffle.
private final Optional<HttpNativeExecutionTaskResultFetcher> taskResultFetcher;
private final Object taskFinishedOrHasResult = new Object();

public NativeExecutionTask(
Session session,
Expand Down Expand Up @@ -116,11 +117,13 @@ public NativeExecutionTask(
this.workerClient,
this.executor,
taskManagerConfig.getInfoUpdateInterval(),
queryManagerConfig.getRemoteTaskMaxErrorDuration());
queryManagerConfig.getRemoteTaskMaxErrorDuration(),
taskFinishedOrHasResult);
if (!shuffleWriteInfo.isPresent()) {
this.taskResultFetcher = Optional.of(new HttpNativeExecutionTaskResultFetcher(
updateScheduledExecutor,
this.workerClient));
this.workerClient,
taskFinishedOrHasResult));
}
else {
this.taskResultFetcher = Optional.empty();
Expand All @@ -139,6 +142,17 @@ public Optional<TaskInfo> getTaskInfo()
return taskInfoFetcher.getTaskInfo();
}

/**
 * Reports whether the remote native task has reached a terminal state.
 *
 * @return true only if a {@link TaskInfo} has been fetched and its status is done
 */
public boolean isTaskDone()
{
    // Absent TaskInfo means the task cannot yet be considered done.
    return getTaskInfo()
            .map(info -> info.getTaskStatus().getState().isDone())
            .orElse(false);
}

/**
 * Returns the shared monitor object that the info fetcher notifies when the task
 * reaches a terminal state and the result fetcher notifies when pages arrive.
 * Callers synchronize and {@code wait()} on it instead of busy-polling.
 */
public Object getTaskFinishedOrHasResult()
{
    return taskFinishedOrHasResult;
}

/**
* Blocking call to poll from result fetcher buffer. Blocks until content becomes available in the buffer, or until timeout is hit.
*
Expand All @@ -153,6 +167,11 @@ public Optional<SerializedPage> pollResult()
return taskResultFetcher.get().pollPage();
}

/**
 * Non-blocking check for buffered result pages.
 *
 * @return true when result fetching is enabled and the fetcher has a buffered page
 */
public boolean hasResult()
{
    // Tasks writing to shuffle have no result fetcher and therefore never report results here.
    return taskResultFetcher
            .map(HttpNativeExecutionTaskResultFetcher::hasPage)
            .orElse(false);
}

/**
* Blocking call to create and start native task.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskState;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.RemoteTransactionHandle;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.metadata.Split;
Expand Down Expand Up @@ -97,18 +96,18 @@
* It will send necessary metadata (e.g, plan fragment, session properties etc.) as a part of
* BatchTaskUpdateRequest. It will poll the remote CPP task for status and results (pages/data if applicable)
* and send these back to the Spark's RDD api
*
* <p>
* PrestoSparkNativeTaskExecutorFactory is singleton instantiated once per executor.
*
* <p>
* For every task it receives, it does the following
* 1. Create the Native execution Process (NativeTaskExecutionFactory) ensure that is it created only once.
* 2. Serialize and pass the planFragment, source-metadata (taskSources), sink-metadata (tableWriteInfo or shuffleWriteInfo)
* and submit a nativeExecutionTask.
* and submit a nativeExecutionTask.
* 3. Return Iterator to sparkRDD layer. RDD execution will call the .next() methods, which will
* 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve {@link SerializedPage} back from external process.
* 3.b If no more output is available, then check if task has finished successfully or with exception
* If task finished with exception - fail the spark task (throw exception)
* IF task finished successfully - collect statistics through taskInfo object and add to accumulator
* 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve {@link SerializedPage} back from external process.
* 3.b If no more output is available, then check if task has finished successfully or with exception
* If task finished with exception - fail the spark task (throw exception)
* IF task finished successfully - collect statistics through taskInfo object and add to accumulator
*/
public class PrestoSparkNativeTaskExecutorFactory
implements IPrestoSparkTaskExecutorFactory
Expand All @@ -123,7 +122,6 @@ public class PrestoSparkNativeTaskExecutorFactory
private static final TaskId DUMMY_TASK_ID = TaskId.valueOf("remotesourcetaskid.0.0.0.0");

private final SessionPropertyManager sessionPropertyManager;
private final FunctionAndTypeManager functionAndTypeManager;
private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
private final Codec<TaskSource> taskSourceCodec;
private final Codec<TaskInfo> taskInfoCodec;
Expand All @@ -137,7 +135,6 @@ public class PrestoSparkNativeTaskExecutorFactory
@Inject
public PrestoSparkNativeTaskExecutorFactory(
SessionPropertyManager sessionPropertyManager,
FunctionAndTypeManager functionAndTypeManager,
JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
Codec<TaskSource> taskSourceCodec,
Codec<TaskInfo> taskInfoCodec,
Expand All @@ -149,7 +146,6 @@ public PrestoSparkNativeTaskExecutorFactory(
PrestoSparkShuffleInfoTranslator shuffleInfoTranslator)
{
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null");
this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "sparkTaskDescriptorJsonCodec is null");
this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null");
this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
Expand Down Expand Up @@ -255,7 +251,7 @@ public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(
TaskInfo taskInfo = task.start();

// task creation might have failed
processTaskInfoForErrors(taskInfo);
processTaskInfoForErrorsOrCompletion(taskInfo);
// 4. return output to spark RDD layer
return new PrestoSparkNativeTaskOutputIterator<>(task, outputType, taskInfoCollector, taskInfoCodec, executionExceptionFactory);
}
Expand Down Expand Up @@ -286,7 +282,7 @@ private static void completeTask(CollectionAccumulator<SerializedTaskInfo> taskI
PrestoSparkStatsCollectionUtils.collectMetrics(taskInfoOptional.get());
}

private static void processTaskInfoForErrors(TaskInfo taskInfo)
private static void processTaskInfoForErrorsOrCompletion(TaskInfo taskInfo)
{
if (!taskInfo.getTaskStatus().getState().isDone()) {
log.info("processTaskInfoForErrors: task is not done yet.. %s", taskInfo);
Expand Down Expand Up @@ -414,8 +410,9 @@ public PrestoSparkNativeTaskOutputIterator(
/**
* This function is called by Spark's RDD layer to check if there are output pages
* There are 2 scenarios
* 1. ShuffleMap Task - Always returns false. But the internal function calls do all the work needed
* 2. Result Task - True until pages are available. False once all pages have been extracted
* 1. ShuffleMap Task - Always returns false. But the internal function calls do all the work needed
* 2. Result Task - True until pages are available. False once all pages have been extracted
*
* @return if output is available
*/
@Override
Expand All @@ -425,74 +422,72 @@ public boolean hasNext()
return next.isPresent();
}

/** This function returns the next available page fetched from CPP process
*
* Has 3 main responsibilities
* 1) Busy-wait-for-pages-or-completion
*
* Loop until either of the 3 conditions happen
* * 1. We get a page
* * 2. Task has finished successfully
* * 3. Task has finished with error
*
* For ShuffleMap Task, as of now, the CPP process returns no pages.
* So the loop acts as a wait-for-completion loop and returns an Optional.empty()
* once the task has terminated
*
* For a Result Task, this function will return all the pages and Optional.empty()
* once all the pages have been read and the task has terminated
*
* 2) Exception handling
* when there are no pages available, the function checks if the task has finished
* with exceptions and throws the appropriate exception back to spark's RDD processing
* layer
*
* 3) Statistics collection
* For both, when the task finished successfully or with exception, it tries to collect
* statistics if TaskInfo object is available
/**
* This function returns the next available page fetched from CPP process
* <p>
* Has 3 main responsibilities
* 1) wait-for-pages-or-completion
* <p>
* The thread running this method will wait until either of the 3 conditions happen
* * 1. We get a page
* * 2. Task has finished successfully
* * 3. Task has finished with error
* <p>
* For ShuffleMap Task, as of now, the CPP process returns no pages.
* So the thread will be in WAITING state till the CPP task is done and returns an Optional.empty()
* once the task has terminated
* <p>
* For a Result Task, this function will return pages retrieved from CPP side once we got them.
* It returns Optional.empty() once all the pages have been read and the task has terminated.
* <p>
* 2) Exception handling
* The function also checks if the task has finished
* with exceptions and throws the appropriate exception back to spark's RDD processing
* layer
* <p>
* 3) Statistics collection
* For both, when the task finished successfully or with exception, it tries to collect
* statistics if TaskInfo object is available
*
* @return Optional<SerializedPage> outputPage
*/
private Optional<SerializedPage> computeNext()
{
// A while(true) loop is not desirable, but in this case we cannot avoid
// it because of Spark's RDD contract, which is that this iterator either
// returns data or is complete. It CANNOT return null.
// While the remote task is still running and there is no output pages,
// we need to simulate a busy-loop to avoid returning null.
while (true) {
try {
// For ShuffleMap Task, this will always return Optional.empty()
Optional<SerializedPage> pageOptional = nativeExecutionTask.pollResult();

if (pageOptional.isPresent()) {
return pageOptional;
try {
Copy link
Collaborator

@shrinidhijoshi shrinidhijoshi Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can simplify this code by writing 2 different iterators - 1. shuffleMap task and 2. for result task,
Reason being

  • Behavior of Shuffle Map task is fundamentally different and is more straight forward (wait for CompletableFuture returned by NativeExecutionTask) Most of the stage in all of our prod workload would use the ShuffleMap iterator as that is the 99% use-case.

I am concerned the multiple wait()/notifyAll() pattern here might be hiding behaviors that are even harder to spot than the current busy-wait bug

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews @shrinidhijoshi.

Technically, waiting on a CompletableFuture (CompletableFuture::get()) is also busy-waiting, which is not a good idea to have on the main thread.

On the other hand, for our scenarios - asynchronous communication between different threads/processes, I believe the wait()/notifyAll() is the right java primitives to use if we can't use more built-in high level blocks (e.g. blockingQueue etc.)

Object taskFinishedOrHasResult = nativeExecutionTask.getTaskFinishedOrHasResult();
// Blocking wait if task is still running or hasn't produced any output page
synchronized (taskFinishedOrHasResult) {
while (!nativeExecutionTask.isTaskDone() && !nativeExecutionTask.hasResult()) {
taskFinishedOrHasResult.wait();
}
}

try {
Optional<TaskInfo> taskInfo = nativeExecutionTask.getTaskInfo();
// For ShuffleMap Task, this will always return Optional.empty()
Optional<SerializedPage> pageOptional = nativeExecutionTask.pollResult();

// Case1: Task is still running
if (!taskInfo.isPresent() || !taskInfo.get().getTaskStatus().getState().isDone()) {
continue;
}
if (pageOptional.isPresent()) {
return pageOptional;
}

// Case 2: Task finished with errors captured inside taskInfo
processTaskInfoForErrors(taskInfo.get());
}
catch (RuntimeException ex) {
// For a failed task, if taskInfo is present we still want to log the metrics
completeTask(taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec);
throw executionExceptionFactory.toPrestoSparkExecutionException(ex);
// Double check if current task's already done (since thread could be awoken by either having output or task is done above)
synchronized (taskFinishedOrHasResult) {
while (!nativeExecutionTask.isTaskDone()) {
taskFinishedOrHasResult.wait();
}

// Case3: Task terminated with success
break;
}
catch (InterruptedException e) {
log.error(e);
throw new RuntimeException(e);
}

Optional<TaskInfo> taskInfo = nativeExecutionTask.getTaskInfo();

processTaskInfoForErrorsOrCompletion(taskInfo.get());
}
catch (RuntimeException ex) {
// For a failed task, if taskInfo is present we still want to log the metrics
completeTask(taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec);
throw executionExceptionFactory.toPrestoSparkExecutionException(ex);
}
catch (InterruptedException e) {
log.error(e);
throw new RuntimeException(e);
}

// Reaching here marks the end of task processing
Expand Down
Loading