Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
*/
package io.trino.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Suppliers;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.math.Quantiles;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
Expand All @@ -29,23 +34,28 @@
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap;
import static com.google.common.math.Quantiles.percentiles;
import static io.airlift.units.DataSize.succinctBytes;
import static io.trino.spi.StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY;
import static java.lang.String.format;
Expand All @@ -57,6 +67,7 @@ public class TaskDescriptorStorage

private final long maxMemoryInBytes;
private final JsonCodec<Split> splitJsonCodec;
private final StorageStats storageStats;

@GuardedBy("this")
private final Map<QueryId, TaskDescriptors> storages = new HashMap<>();
Expand All @@ -75,6 +86,7 @@ public TaskDescriptorStorage(DataSize maxMemory, JsonCodec<Split> splitJsonCodec
{
this.maxMemoryInBytes = maxMemory.toBytes();
this.splitJsonCodec = requireNonNull(splitJsonCodec, "splitJsonCodec is null");
this.storageStats = new StorageStats(Suppliers.memoizeWithExpiration(this::computeStats, 1, TimeUnit.SECONDS));
}

/**
Expand Down Expand Up @@ -188,47 +200,116 @@ private synchronized void updateMemoryReservation(long delta)
}
}

@Managed
public synchronized long getReservedBytes()
@VisibleForTesting
synchronized long getReservedBytes()
{
return reservedBytes;
}

/**
 * JMX stats bean for this storage, exported with {@code @Nested} so that the
 * {@code @Managed} getters on {@link StorageStats} appear under this MBean.
 */
@Managed
@Nested
public StorageStats getStats()
{
    // This should not contain materialized values. GuiceMBeanExporter calls it only once during application startup
    // and then only @Managed methods are called on that instance.
    return storageStats;
}

/**
 * Computes a point-in-time snapshot of storage statistics across all tracked queries.
 * <p>
 * Invoked through the memoizing supplier created in the constructor, so the relatively
 * expensive percentile computation runs at most once per second regardless of how often
 * the {@code @Managed} getters are polled.
 */
private synchronized StorageStatsValue computeStats()
{
    int queriesCount = storages.size();
    long stagesCount = storages.values().stream().mapToLong(TaskDescriptors::getStagesCount).sum();

    Quantiles.ScaleAndIndexes percentiles = percentiles().indexes(50, 90, 95);

    long queryReservedBytesP50 = 0;
    long queryReservedBytesP90 = 0;
    long queryReservedBytesP95 = 0;
    long queryReservedBytesAvg = 0;
    long stageReservedBytesP50 = 0;
    long stageReservedBytesP90 = 0;
    long stageReservedBytesP95 = 0;
    long stageReservedBytesAvg = 0;

    if (queriesCount > 0) { // we cannot compute percentiles for an empty set
        Map<Integer, Double> queryReservedBytesPercentiles = percentiles.compute(
                storages.values().stream()
                        .map(TaskDescriptors::getReservedBytes)
                        .collect(toImmutableList()));

        queryReservedBytesP50 = queryReservedBytesPercentiles.get(50).longValue();
        queryReservedBytesP90 = queryReservedBytesPercentiles.get(90).longValue();
        queryReservedBytesP95 = queryReservedBytesPercentiles.get(95).longValue();
        // queriesCount > 0 here, so this division is safe
        queryReservedBytesAvg = reservedBytes / queriesCount;

        List<Long> storagesReservedBytes = storages.values().stream()
                .flatMap(TaskDescriptors::getStagesReservedBytes)
                .collect(toImmutableList());

        if (!storagesReservedBytes.isEmpty()) {
            Map<Integer, Double> stagesReservedBytesPercentiles = percentiles.compute(
                    storagesReservedBytes);
            stageReservedBytesP50 = stagesReservedBytesPercentiles.get(50).longValue();
            stageReservedBytesP90 = stagesReservedBytesPercentiles.get(90).longValue();
            stageReservedBytesP95 = stagesReservedBytesPercentiles.get(95).longValue();
            // stagesCount is derived from the descriptors table (rows disappear once a stage's
            // last descriptor is removed) while per-stage reserved-bytes entries are retained,
            // so a non-empty storagesReservedBytes does not guarantee stagesCount > 0; guard
            // against division by zero
            stageReservedBytesAvg = stagesCount > 0 ? reservedBytes / stagesCount : 0;
        }
    }

    return new StorageStatsValue(
            queriesCount,
            stagesCount,
            reservedBytes,
            queryReservedBytesAvg,
            queryReservedBytesP50,
            queryReservedBytesP90,
            queryReservedBytesP95,
            stageReservedBytesAvg,
            stageReservedBytesP50,
            stageReservedBytesP90,
            stageReservedBytesP95);
}

@NotThreadSafe
private class TaskDescriptors
{
private final Map<TaskDescriptorKey, TaskDescriptor> descriptors = new HashMap<>();
private final Table<StageId, Integer /* partitionId */, TaskDescriptor> descriptors = HashBasedTable.create();

private long reservedBytes;
private final Map<StageId, AtomicLong> stagesReservedBytes = new HashMap<>();
private RuntimeException failure;

public void put(StageId stageId, int partitionId, TaskDescriptor descriptor)
{
throwIfFailed();
TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
checkState(descriptors.putIfAbsent(key, descriptor) == null, "task descriptor is already present for key %s ", key);
reservedBytes += descriptor.getRetainedSizeInBytes();
checkState(!descriptors.contains(stageId, partitionId), "task descriptor is already present for key %s/%s ", stageId, partitionId);
descriptors.put(stageId, partitionId, descriptor);
long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes();
reservedBytes += descriptorRetainedBytes;
stagesReservedBytes.computeIfAbsent(stageId, ignored -> new AtomicLong()).addAndGet(descriptorRetainedBytes);
}

public TaskDescriptor get(StageId stageId, int partitionId)
{
throwIfFailed();
TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
TaskDescriptor descriptor = descriptors.get(key);
TaskDescriptor descriptor = descriptors.get(stageId, partitionId);
if (descriptor == null) {
throw new NoSuchElementException(format("descriptor not found for key %s", key));
throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId));
}
return descriptor;
}

public void remove(StageId stageId, int partitionId)
{
throwIfFailed();
TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
TaskDescriptor descriptor = descriptors.remove(key);
TaskDescriptor descriptor = descriptors.remove(stageId, partitionId);
if (descriptor == null) {
throw new NoSuchElementException(format("descriptor not found for key %s", key));
throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId));
}
reservedBytes -= descriptor.getRetainedSizeInBytes();
long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes();
reservedBytes -= descriptorRetainedBytes;
requireNonNull(stagesReservedBytes.get(stageId), () -> format("no entry for stage %s", stageId)).addAndGet(-descriptorRetainedBytes);
}

public long getReservedBytes()
Expand All @@ -238,10 +319,10 @@ public long getReservedBytes()

private String getDebugInfo()
{
Multimap<StageId, TaskDescriptor> descriptorsByStageId = descriptors.entrySet().stream()
Multimap<StageId, TaskDescriptor> descriptorsByStageId = descriptors.cellSet().stream()
.collect(toImmutableSetMultimap(
entry -> entry.getKey().getStageId(),
Map.Entry::getValue));
Table.Cell::getRowKey,
Table.Cell::getValue));

Map<StageId, String> debugInfoByStageId = descriptorsByStageId.asMap().entrySet().stream()
.collect(toImmutableMap(
Expand Down Expand Up @@ -299,55 +380,105 @@ private void throwIfFailed()
throw failure;
}
}

public int getStagesCount()
{
return descriptors.rowMap().size();
}

public Stream<Long> getStagesReservedBytes()
{
return stagesReservedBytes.values().stream()
.map(AtomicLong::get);
}
}

private static class TaskDescriptorKey
/**
 * Immutable snapshot of the statistics exposed through {@code StorageStats}.
 * All values are in bytes except the counts; per-query and per-stage averages and
 * percentiles are computed over the currently tracked queries and stages.
 */
private record StorageStatsValue(
long queriesCount,
long stagesCount,
long reservedBytes,
long queryReservedBytesAvg,
long queryReservedBytesP50,
long queryReservedBytesP90,
long queryReservedBytesP95,
long stageReservedBytesAvg,
long stageReservedBytesP50,
long stageReservedBytesP90,
long stageReservedBytesP95) {}

public static class StorageStats
Comment thread
findepi marked this conversation as resolved.
Outdated
{
private final StageId stageId;
private final int partitionId;
private final Supplier<StorageStatsValue> statsSupplier;

private TaskDescriptorKey(StageId stageId, int partitionId)
StorageStats(Supplier<StorageStatsValue> statsSupplier)
{
this.stageId = requireNonNull(stageId, "stageId is null");
this.partitionId = partitionId;
this.statsSupplier = requireNonNull(statsSupplier, "statsSupplier is null");
}

public StageId getStageId()
@Managed
public long getQueriesCount()
{
return stageId;
return statsSupplier.get().queriesCount();
}

public int getPartitionId()
@Managed
public long getStagesCount()
{
return partitionId;
return statsSupplier.get().stagesCount();
}

@Override
public boolean equals(Object o)
@Managed
public long getReservedBytes()
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TaskDescriptorKey key = (TaskDescriptorKey) o;
return partitionId == key.partitionId && Objects.equals(stageId, key.stageId);
return statsSupplier.get().reservedBytes();
}

/**
 * Average reserved task descriptor memory per tracked query, in bytes.
 */
@Managed
public long getQueryReservedBytesAvg()
{
    StorageStatsValue stats = statsSupplier.get();
    return stats.queryReservedBytesAvg();
}

/**
 * 50th percentile of per-query reserved task descriptor memory, in bytes.
 */
@Managed
public long getQueryReservedBytesP50()
{
    StorageStatsValue stats = statsSupplier.get();
    return stats.queryReservedBytesP50();
}

/**
 * 90th percentile of per-query reserved task descriptor memory, in bytes.
 */
@Managed
public long getQueryReservedBytesP90()
{
    StorageStatsValue stats = statsSupplier.get();
    return stats.queryReservedBytesP90();
}

/**
 * 95th percentile of per-query reserved task descriptor memory, in bytes.
 */
@Managed
public long getQueryReservedBytesP95()
{
    StorageStatsValue stats = statsSupplier.get();
    return stats.queryReservedBytesP95();
}

/**
 * Average reserved task descriptor memory per stage, in bytes.
 */
@Managed
public long getStageReservedBytesAvg()
{
    // fix copy-paste bug: the "Avg" getter was returning stageReservedBytesP50()
    return statsSupplier.get().stageReservedBytesAvg();
}

/**
 * 50th percentile of per-stage reserved task descriptor memory, in bytes.
 */
@Managed
public long getStageReservedBytesP50()
{
    StorageStatsValue stats = statsSupplier.get();
    return stats.stageReservedBytesP50();
}

@Override
public int hashCode()
@Managed
public long getStageReservedBytesP90()
{
return Objects.hash(stageId, partitionId);
return statsSupplier.get().stageReservedBytesP90();
}

@Override
public String toString()
@Managed
public long getStageReservedBytesP95()
{
return toStringHelper(this)
.add("stageId", stageId)
.add("partitionId", partitionId)
.toString();
return statsSupplier.get().stageReservedBytesP95();
}
}
}