diff --git a/presto-testing/src/main/java/io/prestosql/testing/sql/TestTable.java b/presto-testing/src/main/java/io/prestosql/testing/sql/TestTable.java index dc21f18d2169..70614ef0416d 100644 --- a/presto-testing/src/main/java/io/prestosql/testing/sql/TestTable.java +++ b/presto-testing/src/main/java/io/prestosql/testing/sql/TestTable.java @@ -17,11 +17,15 @@ import java.security.SecureRandom; import java.util.List; +import java.util.Map; +import java.util.function.Function; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.Character.MAX_RADIX; import static java.lang.Math.abs; import static java.lang.Math.min; import static java.lang.String.format; +import static java.lang.String.join; public class TestTable implements AutoCloseable @@ -60,6 +64,57 @@ public String getName() return name; } + public static TestTable fromColumns(SqlExecutor sqlExecutor, String namePrefix, Map> columns) + { + return fromColumns( + sqlExecutor, + namePrefix, + columns, + column -> { + throw new IllegalArgumentException(String.format("Some values missing for column '%s'", column)); + }); + } + + public static TestTable fromColumns(SqlExecutor sqlExecutor, String namePrefix, Map> columns, String defaultValue) + { + return fromColumns(sqlExecutor, namePrefix, columns, column -> defaultValue); + } + + private static TestTable fromColumns(SqlExecutor sqlExecutor, String namePrefix, Map> columns, Function defaultValueSupplier) + { + int rowsCount = columns.values().stream() + .mapToInt(List::size) + .max().orElseThrow(() -> new IllegalArgumentException("please, give me at least one column to work with")); + return fromColumnValueProviders( + sqlExecutor, + namePrefix, + rowsCount, + columns.entrySet().stream() + .collect(toImmutableMap( + Map.Entry::getKey, + entry -> index -> { + if (index < entry.getValue().size()) { + return entry.getValue().get(index); + } + return defaultValueSupplier.apply(entry.getKey()); + }))); + } + + public static TestTable fromColumnValueProviders(SqlExecutor sqlExecutor, String namePrefix, int rowsCount, Map> columnsValueProviders) + { + String tableDefinition = "(" + join(",", columnsValueProviders.keySet()) + ")"; + + ImmutableList.Builder rows = ImmutableList.builder(); + for (int rowId = 0; rowId < rowsCount; rowId++) { + ImmutableList.Builder rowValues = ImmutableList.builder(); + for (Function columnValues : columnsValueProviders.values()) { + rowValues.add(columnValues.apply(rowId)); + } + rows.add(join(",", rowValues.build())); + } + return new TestTable(sqlExecutor, namePrefix, tableDefinition, rows.build()); + } + @Override public void close() {