diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java index 6beb1827cebd..874142bd92b4 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusRecordCursor.java @@ -14,7 +14,6 @@ package io.trino.plugin.prometheus; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteSource; import com.google.common.io.CountingInputStream; import io.airlift.slice.Slice; @@ -119,7 +118,7 @@ private Object getFieldValue(int field) int columnIndex = fieldToColumnIndex[field]; switch (columnIndex) { case 0: - return fields.getLabels(); + return getBlockFromMap(columnHandles.get(columnIndex).getColumnType(), fields.getLabels()); case 1: return fields.getTimestamp(); case 2: @@ -186,7 +185,7 @@ private List prometheusResultsInStandardizedForm(List { return results.stream().map(result -> result.getTimeSeriesValues().getValues().stream().map(prometheusTimeSeriesValue -> new PrometheusStandardizedRow( - getBlockFromMap(columnHandles.get(0).getColumnType(), ImmutableMap.copyOf(result.getMetricHeader())), + result.getMetricHeader(), prometheusTimeSeriesValue.getTimestamp(), Double.parseDouble(prometheusTimeSeriesValue.getValue()))) .collect(Collectors.toList())) diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusStandardizedRow.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusStandardizedRow.java index 0ad038807453..e74d950c8198 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusStandardizedRow.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusStandardizedRow.java @@ -13,26 +13,27 @@ */ package io.trino.plugin.prometheus; -import io.trino.spi.block.Block; +import com.google.common.collect.ImmutableMap; import java.time.Instant; +import java.util.Map; import static java.util.Objects.requireNonNull; public class PrometheusStandardizedRow { - private final Block labels; + private final Map labels; private final Instant timestamp; private final Double value; - public PrometheusStandardizedRow(Block labels, Instant timestamp, Double value) + public PrometheusStandardizedRow(Map labels, Instant timestamp, Double value) { - this.labels = requireNonNull(labels, "labels is null"); + this.labels = ImmutableMap.copyOf(requireNonNull(labels, "labels is null")); this.timestamp = requireNonNull(timestamp, "timestamp is null"); this.value = requireNonNull(value, "value is null"); } - public Block getLabels() + public Map getLabels() { return labels; } diff --git a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusIntegration.java b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusIntegration.java index a2dd9254279d..57a853a41d72 100644 --- a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusIntegration.java +++ b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusIntegration.java @@ -56,6 +56,13 @@ public void testSelectTable() .matches("SELECT MAP(ARRAY[VARCHAR 'instance', '__name__', 'job'], ARRAY[VARCHAR 'localhost:9090', 'up', 'prometheus'])"); } + @Test + public void testAggregation() + { + assertQuerySucceeds("SELECT count(*) FROM default.up"); // Don't check value since the row number isn't deterministic + assertQuery("SELECT avg(value) FROM default.up", "VALUES ('1.0')"); + } + @Test public void testPushDown() { diff --git a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusRecordSet.java b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusRecordSet.java index 95ea11885ad9..7405ec17b4cf 100644 --- a/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusRecordSet.java +++ b/plugin/trino-prometheus/src/test/java/io/trino/plugin/prometheus/TestPrometheusRecordSet.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.plugin.prometheus.MetadataUtil.METRIC_CODEC; import static io.trino.plugin.prometheus.MetadataUtil.varcharMapType; import static io.trino.plugin.prometheus.PrometheusClient.TIMESTAMP_COLUMN_TYPE; @@ -62,7 +63,8 @@ public void testCursorSimple() List actual = new ArrayList<>(); while (cursor.advanceNextPosition()) { actual.add(new PrometheusStandardizedRow( - (Block) cursor.getObject(0), + getMapFromBlock(varcharMapType, (Block) cursor.getObject(0)).entrySet().stream() + .collect(toImmutableMap(entry -> (String) entry.getKey(), entry -> (String) entry.getValue())), (Instant) cursor.getObject(1), cursor.getDouble(2))); assertFalse(cursor.isNull(0)); @@ -70,14 +72,14 @@ public void testCursorSimple() assertFalse(cursor.isNull(2)); } List expected = ImmutableList.builder() - .add(new PrometheusStandardizedRow(getBlockFromMap(varcharMapType, - ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus")), ofEpochMilli(1565962969044L), 1.0)) - .add(new PrometheusStandardizedRow(getBlockFromMap(varcharMapType, - ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus")), ofEpochMilli(1565962984045L), 1.0)) - .add(new PrometheusStandardizedRow(getBlockFromMap(varcharMapType, - ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus")), ofEpochMilli(1565962999044L), 1.0)) - .add(new PrometheusStandardizedRow(getBlockFromMap(varcharMapType, - ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus")), ofEpochMilli(1565963014044L), 1.0)) + .add(new PrometheusStandardizedRow( + ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus"), ofEpochMilli(1565962969044L), 1.0)) + .add(new PrometheusStandardizedRow( + ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus"), ofEpochMilli(1565962984045L), 1.0)) + .add(new PrometheusStandardizedRow( + ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus"), ofEpochMilli(1565962999044L), 1.0)) + .add(new PrometheusStandardizedRow( + ImmutableMap.of("instance", "localhost:9090", "__name__", "up", "job", "prometheus"), ofEpochMilli(1565963014044L), 1.0)) .build(); assertThat(actual).as("actual") @@ -85,7 +87,7 @@ public void testCursorSimple() for (int i = 0; i < actual.size(); i++) { PrometheusStandardizedRow actualRow = actual.get(i); PrometheusStandardizedRow expectedRow = expected.get(i); - assertEquals(getMapFromBlock(varcharMapType, actualRow.getLabels()), getMapFromBlock(varcharMapType, expectedRow.getLabels())); + assertEquals(getMapFromBlock(varcharMapType, getBlockFromMap(varcharMapType, actualRow.getLabels())), getMapFromBlock(varcharMapType, getBlockFromMap(varcharMapType, expectedRow.getLabels()))); assertEquals(actualRow.getTimestamp(), expectedRow.getTimestamp()); assertEquals(actualRow.getValue(), expectedRow.getValue()); }