Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1013,10 +1013,10 @@ private TaskFailureReporter(AtomicReference<DistributedStagesScheduler> distribu
@Override
public void onTaskFailed(TaskId taskId, Throwable failure)
{
if (failure instanceof TrinoException && ((TrinoException) failure).getErrorCode() == REMOTE_TASK_FAILED.toErrorCode()) {
if (failure instanceof TrinoException && REMOTE_TASK_FAILED.toErrorCode().equals(((TrinoException) failure).getErrorCode())) {
// This error indicates that a downstream task was trying to fetch results from an upstream task that is marked as failed
// Instead of failing a downstream task let the coordinator handle and report the failure of an upstream task to ensure correct error reporting
log.info("Task failure discovered while fetching task results: %s", taskId);
log.debug("Task failure discovered while fetching task results: %s", taskId);
return;
}
log.warn(failure, "Reported task failure: %s", taskId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ListMultimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.execution.TaskId;
Expand All @@ -35,18 +36,20 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
import static io.trino.operator.RetryPolicy.QUERY;
import static io.trino.operator.RetryPolicy.TASK;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;

public class DeduplicationExchangeClientBuffer
implements ExchangeClientBuffer
{
private static final Logger log = Logger.get(DeduplicationExchangeClientBuffer.class);

private final Executor executor;
private final long bufferCapacityInBytes;
private final RetryPolicy retryPolicy;
Expand Down Expand Up @@ -255,21 +258,32 @@ private synchronized void checkInputFinished()
return;
}

List<Throwable> failures = failedTasks.entrySet().stream()
.filter(entry -> entry.getKey().getAttemptId() == maxAttemptId)
.map(Map.Entry::getValue)
.collect(toImmutableList());

if (!failures.isEmpty()) {
Throwable failure = null;
for (Throwable taskFailure : failures) {
if (failure == null) {
failure = taskFailure;
}
else if (failure != taskFailure) {
failure.addSuppressed(taskFailure);
}
Throwable failure = null;
for (Map.Entry<TaskId, Throwable> entry : failedTasks.entrySet()) {
TaskId taskId = entry.getKey();
Throwable taskFailure = entry.getValue();

if (taskId.getAttemptId() != maxAttemptId) {
// ignore failures from previous attempts
continue;
}

if (taskFailure instanceof TrinoException && REMOTE_TASK_FAILED.toErrorCode().equals(((TrinoException) taskFailure).getErrorCode())) {
// This error indicates that a downstream task was trying to fetch results from an upstream task that is marked as failed
// Instead of failing a downstream task let the coordinator handle and report the failure of an upstream task to ensure correct error reporting
log.debug("Task failure discovered while fetching task results: %s", taskId);
continue;
}

if (failure == null) {
failure = taskFailure;
}
else if (failure != taskFailure) {
failure.addSuppressed(taskFailure);
}
}

if (failure != null) {
pageBuffer.clear();
bufferRetainedSizeInBytes = 0;
this.failure = failure;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.execution.TaskId;
Expand All @@ -39,6 +40,8 @@
public class StreamingExchangeClientBuffer
implements ExchangeClientBuffer
{
private static final Logger log = Logger.get(StreamingExchangeClientBuffer.class);

private final Executor executor;
private final long bufferCapacityInBytes;

Expand Down Expand Up @@ -141,8 +144,10 @@ public synchronized void taskFailed(TaskId taskId, Throwable t)
}
checkState(activeTasks.contains(taskId), "taskId not registered: %s", taskId);

if (t instanceof TrinoException && ((TrinoException) t).getErrorCode() == REMOTE_TASK_FAILED.toErrorCode()) {
// let coordinator handle this
if (t instanceof TrinoException && REMOTE_TASK_FAILED.toErrorCode().equals(((TrinoException) t).getErrorCode())) {
// This error indicates that a downstream task was trying to fetch results from an upstream task that is marked as failed
// Instead of failing a downstream task let the coordinator handle and report the failure of an upstream task to ensure correct error reporting
log.debug("Task failure discovered while fetching task results: %s", taskId);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.units.DataSize.Unit.BYTE;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;

public class TestDeduplicationExchangeClientBuffer
Expand Down Expand Up @@ -455,6 +457,36 @@ public void testRemainingBufferCapacity()
}
}

@Test
public void testRemoteTaskFailedError()
{
// fail before noMoreTasks
try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) {
TaskId taskId = createTaskId(0, 0);
buffer.addTask(taskId);
buffer.taskFailed(taskId, new TrinoException(REMOTE_TASK_FAILED, "Remote task failed"));
buffer.noMoreTasks();

assertFalse(buffer.isFinished());
assertFalse(buffer.isFailed());
assertBlocked(buffer.isBlocked());
assertNull(buffer.pollPage());
}

// fail after noMoreTasks
try (ExchangeClientBuffer buffer = new DeduplicationExchangeClientBuffer(directExecutor(), ONE_KB, RetryPolicy.QUERY)) {
TaskId taskId = createTaskId(0, 0);
buffer.addTask(taskId);
buffer.noMoreTasks();
buffer.taskFailed(taskId, new TrinoException(REMOTE_TASK_FAILED, "Remote task failed"));

assertFalse(buffer.isFinished());
assertFalse(buffer.isFailed());
assertBlocked(buffer.isBlocked());
assertNull(buffer.pollPage());
}
}

private static TaskId createTaskId(int partition, int attempt)
{
return new TaskId(new StageId("query", 0), partition, attempt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import org.testng.annotations.Test;

import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_FAILED;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
Expand Down Expand Up @@ -196,4 +198,30 @@ public void testFutureCancellationDoesNotAffectOtherFutures()
assertTrue(blocked2.isDone());
}
}

@Test
public void testRemoteTaskFailedError()
{
// fail before noMoreTasks
try (ExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) {
buffer.addTask(TASK_0);
buffer.taskFailed(TASK_0, new TrinoException(REMOTE_TASK_FAILED, "Remote task failed"));
buffer.noMoreTasks();

assertFalse(buffer.isFinished());
assertFalse(buffer.isFailed());
assertNull(buffer.pollPage());
}

// fail after noMoreTasks
try (ExchangeClientBuffer buffer = new StreamingExchangeClientBuffer(directExecutor(), DataSize.of(1, KILOBYTE))) {
buffer.addTask(TASK_0);
buffer.noMoreTasks();
buffer.taskFailed(TASK_0, new TrinoException(REMOTE_TASK_FAILED, "Remote task failed"));

assertFalse(buffer.isFinished());
assertFalse(buffer.isFailed());
assertNull(buffer.pollPage());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -363,17 +363,15 @@ protected void testNonSelect(Optional<Session> session, Optional<String> setupQu
.withCleanupQuery(cleanupQuery)
.experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
.at(boundaryCoordinatorStage())
// TODO(https://github.com/trinodb/trino/issues/10395 original exception message is lost sometimes
.failsAlways(failure -> failure.hasMessageFindingMatch("\\Q" + FAILURE_INJECTION_MESSAGE + "\\E|Remote task failed.*"));
.failsAlways(failure -> failure.hasMessageContaining(FAILURE_INJECTION_MESSAGE));

assertThatQuery(query)
.withSession(session)
.withSetupQuery(setupQuery)
.withCleanupQuery(cleanupQuery)
.experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
.at(rootStage())
// TODO(https://github.com/trinodb/trino/issues/10395 original exception message is lost sometimes
.failsAlways(failure -> failure.hasMessageFindingMatch("\\Q" + FAILURE_INJECTION_MESSAGE + "\\E|Remote task failed.*"));
.failsAlways(failure -> failure.hasMessageContaining(FAILURE_INJECTION_MESSAGE));

assertThatQuery(query)
.withSession(session)
Expand Down