diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptorStorage.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptorStorage.java index d00036326ac7..10a17253030e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptorStorage.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskDescriptorStorage.java @@ -13,8 +13,13 @@ */ package io.trino.execution.scheduler; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Suppliers; import com.google.common.base.VerifyException; +import com.google.common.collect.HashBasedTable; import com.google.common.collect.Multimap; +import com.google.common.collect.Table; +import com.google.common.math.Quantiles; import com.google.common.math.Stats; import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Inject; @@ -29,6 +34,7 @@ import io.trino.spi.TrinoException; import io.trino.sql.planner.plan.PlanNodeId; import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; import java.util.Collection; import java.util.Comparator; @@ -36,16 +42,20 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; -import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Stream; -import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap; +import static com.google.common.math.Quantiles.percentiles; import static io.airlift.units.DataSize.succinctBytes; import static io.trino.spi.StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY; import static java.lang.String.format; @@ -57,6 +67,7 @@ public class TaskDescriptorStorage private final long maxMemoryInBytes; private final JsonCodec splitJsonCodec; + private final StorageStats storageStats; @GuardedBy("this") private final Map storages = new HashMap<>(); @@ -75,6 +86,7 @@ public TaskDescriptorStorage(DataSize maxMemory, JsonCodec splitJsonCodec { this.maxMemoryInBytes = maxMemory.toBytes(); this.splitJsonCodec = requireNonNull(splitJsonCodec, "splitJsonCodec is null"); + this.storageStats = new StorageStats(Suppliers.memoizeWithExpiration(this::computeStats, 1, TimeUnit.SECONDS)); } /** @@ -188,34 +200,102 @@ private synchronized void updateMemoryReservation(long delta) } } - @Managed - public synchronized long getReservedBytes() + @VisibleForTesting + synchronized long getReservedBytes() { return reservedBytes; } + @Managed + @Nested + public StorageStats getStats() + { + // This should not contain materialized values. GuiceMBeanExporter calls it only once during application startup + // and then only @Managed methods all called on that instance. + return storageStats; + } + + private synchronized StorageStatsValue computeStats() + { + int queriesCount = storages.size(); + long stagesCount = storages.values().stream().mapToLong(TaskDescriptors::getStagesCount).sum(); + + Quantiles.ScaleAndIndexes percentiles = percentiles().indexes(50, 90, 95); + + long queryReservedBytesP50 = 0; + long queryReservedBytesP90 = 0; + long queryReservedBytesP95 = 0; + long queryReservedBytesAvg = 0; + long stageReservedBytesP50 = 0; + long stageReservedBytesP90 = 0; + long stageReservedBytesP95 = 0; + long stageReservedBytesAvg = 0; + + if (queriesCount > 0) { // we cannot compute percentiles for empty set + + Map queryReservedBytesPercentiles = percentiles.compute( + storages.values().stream() + .map(TaskDescriptors::getReservedBytes) + .collect(toImmutableList())); + + queryReservedBytesP50 = queryReservedBytesPercentiles.get(50).longValue(); + queryReservedBytesP90 = queryReservedBytesPercentiles.get(90).longValue(); + queryReservedBytesP95 = queryReservedBytesPercentiles.get(95).longValue(); + queryReservedBytesAvg = reservedBytes / queriesCount; + + List storagesReservedBytes = storages.values().stream() + .flatMap(TaskDescriptors::getStagesReservedBytes) + .collect(toImmutableList()); + + if (!storagesReservedBytes.isEmpty()) { + Map stagesReservedBytesPercentiles = percentiles.compute( + storagesReservedBytes); + stageReservedBytesP50 = stagesReservedBytesPercentiles.get(50).longValue(); + stageReservedBytesP90 = stagesReservedBytesPercentiles.get(90).longValue(); + stageReservedBytesP95 = stagesReservedBytesPercentiles.get(95).longValue(); + stageReservedBytesAvg = reservedBytes / stagesCount; + } + } + + return new StorageStatsValue( + queriesCount, + stagesCount, + reservedBytes, + queryReservedBytesAvg, + queryReservedBytesP50, + queryReservedBytesP90, + queryReservedBytesP95, + stageReservedBytesAvg, + stageReservedBytesP50, + stageReservedBytesP90, + stageReservedBytesP95); + } + @NotThreadSafe private class TaskDescriptors { - private final Map descriptors = new HashMap<>(); + private final Table descriptors = HashBasedTable.create(); + private long reservedBytes; + private final Map stagesReservedBytes = new HashMap<>(); private RuntimeException failure; public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) { throwIfFailed(); - TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId); - checkState(descriptors.putIfAbsent(key, descriptor) == null, "task descriptor is already present for key %s ", key); - reservedBytes += descriptor.getRetainedSizeInBytes(); + checkState(!descriptors.contains(stageId, partitionId), "task descriptor is already present for key %s/%s ", stageId, partitionId); + descriptors.put(stageId, partitionId, descriptor); + long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes(); + reservedBytes += descriptorRetainedBytes; + stagesReservedBytes.computeIfAbsent(stageId, ignored -> new AtomicLong()).addAndGet(descriptorRetainedBytes); } public TaskDescriptor get(StageId stageId, int partitionId) { throwIfFailed(); - TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId); - TaskDescriptor descriptor = descriptors.get(key); + TaskDescriptor descriptor = descriptors.get(stageId, partitionId); if (descriptor == null) { - throw new NoSuchElementException(format("descriptor not found for key %s", key)); + throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId)); } return descriptor; } @@ -223,12 +303,13 @@ public TaskDescriptor get(StageId stageId, int partitionId) public void remove(StageId stageId, int partitionId) { throwIfFailed(); - TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId); - TaskDescriptor descriptor = descriptors.remove(key); + TaskDescriptor descriptor = descriptors.remove(stageId, partitionId); if (descriptor == null) { - throw new NoSuchElementException(format("descriptor not found for key %s", key)); + throw new NoSuchElementException(format("descriptor not found for key %s/%s", stageId, partitionId)); } - reservedBytes -= descriptor.getRetainedSizeInBytes(); + long descriptorRetainedBytes = descriptor.getRetainedSizeInBytes(); + reservedBytes -= descriptorRetainedBytes; + requireNonNull(stagesReservedBytes.get(stageId), () -> format("no entry for stage %s", stageId)).addAndGet(-descriptorRetainedBytes); } public long getReservedBytes() @@ -238,10 +319,10 @@ public long getReservedBytes() private String getDebugInfo() { - Multimap descriptorsByStageId = descriptors.entrySet().stream() + Multimap descriptorsByStageId = descriptors.cellSet().stream() .collect(toImmutableSetMultimap( - entry -> entry.getKey().getStageId(), - Map.Entry::getValue)); + Table.Cell::getRowKey, + Table.Cell::getValue)); Map debugInfoByStageId = descriptorsByStageId.asMap().entrySet().stream() .collect(toImmutableMap( @@ -299,55 +380,105 @@ private void throwIfFailed() throw failure; } } + + public int getStagesCount() + { + return descriptors.rowMap().size(); + } + + public Stream getStagesReservedBytes() + { + return stagesReservedBytes.values().stream() + .map(AtomicLong::get); + } } - private static class TaskDescriptorKey + private record StorageStatsValue( + long queriesCount, + long stagesCount, + long reservedBytes, + long queryReservedBytesAvg, + long queryReservedBytesP50, + long queryReservedBytesP90, + long queryReservedBytesP95, + long stageReservedBytesAvg, + long stageReservedBytesP50, + long stageReservedBytesP90, + long stageReservedBytesP95) {} + + public static class StorageStats { - private final StageId stageId; - private final int partitionId; + private final Supplier statsSupplier; - private TaskDescriptorKey(StageId stageId, int partitionId) + StorageStats(Supplier statsSupplier) { - this.stageId = requireNonNull(stageId, "stageId is null"); - this.partitionId = partitionId; + this.statsSupplier = requireNonNull(statsSupplier, "statsSupplier is null"); } - public StageId getStageId() + @Managed + public long getQueriesCount() { - return stageId; + return statsSupplier.get().queriesCount(); } - public int getPartitionId() + @Managed + public long getStagesCount() { - return partitionId; + return statsSupplier.get().stagesCount(); } - @Override - public boolean equals(Object o) + @Managed + public long getReservedBytes() { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TaskDescriptorKey key = (TaskDescriptorKey) o; - return partitionId == key.partitionId && Objects.equals(stageId, key.stageId); + return statsSupplier.get().reservedBytes(); + } + + @Managed + public long getQueryReservedBytesAvg() + { + return statsSupplier.get().queryReservedBytesAvg(); + } + + @Managed + public long getQueryReservedBytesP50() + { + return statsSupplier.get().queryReservedBytesP50(); + } + + @Managed + public long getQueryReservedBytesP90() + { + return statsSupplier.get().queryReservedBytesP90(); + } + + @Managed + public long getQueryReservedBytesP95() + { + return statsSupplier.get().queryReservedBytesP95(); + } + + @Managed + public long getStageReservedBytesAvg() + { + return statsSupplier.get().stageReservedBytesP50(); + } + + @Managed + public long getStageReservedBytesP50() + { + return statsSupplier.get().stageReservedBytesP50(); } - @Override - public int hashCode() + @Managed + public long getStageReservedBytesP90() { - return Objects.hash(stageId, partitionId); + return statsSupplier.get().stageReservedBytesP90(); } - @Override - public String toString() + @Managed + public long getStageReservedBytesP95() { - return toStringHelper(this) - .add("stageId", stageId) - .add("partitionId", partitionId) - .toString(); + return statsSupplier.get().stageReservedBytesP95(); } } }