Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@

import static java.util.Objects.requireNonNull;

public record MemoryInsertTableHandle(long table, Set<Long> activeTableIds)
public record MemoryInsertTableHandle(long table, InsertMode mode, Set<Long> activeTableIds)
implements ConnectorInsertTableHandle
{
public enum InsertMode
{
APPEND, OVERWRITE
}

public MemoryInsertTableHandle
{
requireNonNull(mode, "mode is null");
activeTableIds = ImmutableSet.copyOf(requireNonNull(activeTableIds, "activeTableIds is null"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.slice.Slice;
import io.trino.plugin.memory.MemoryInsertTableHandle.InsertMode;
import io.trino.spi.HostAddress;
import io.trino.spi.Node;
import io.trino.spi.NodeManager;
Expand Down Expand Up @@ -277,7 +278,7 @@ public synchronized void renameTable(ConnectorSession session, ConnectorTableHan
long tableId = handle.id();

TableInfo oldInfo = tables.get(tableId);
tables.put(tableId, new TableInfo(tableId, newTableName.getSchemaName(), newTableName.getTableName(), oldInfo.columns(), oldInfo.dataFragments(), oldInfo.comment()));
tables.put(tableId, new TableInfo(tableId, newTableName.getSchemaName(), newTableName.getTableName(), oldInfo.columns(), oldInfo.truncated(), oldInfo.dataFragments(), oldInfo.comment()));

tableIds.remove(oldInfo.getSchemaTableName());
tableIds.put(newTableName, tableId);
Expand Down Expand Up @@ -311,6 +312,7 @@ public synchronized MemoryOutputTableHandle beginCreateTable(ConnectorSession se
tableMetadata.getTable().getSchemaName(),
tableMetadata.getTable().getTableName(),
columns.build(),
false,
new HashMap<>(),
tableMetadata.getComment()));

Expand Down Expand Up @@ -350,7 +352,10 @@ public synchronized Optional<ConnectorOutputMetadata> finishCreateTable(Connecto
public synchronized MemoryInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List<ColumnHandle> columns, RetryMode retryMode)
{
MemoryTableHandle memoryTableHandle = (MemoryTableHandle) tableHandle;
return new MemoryInsertTableHandle(memoryTableHandle.id(), ImmutableSet.copyOf(tableIds.values()));
TableInfo tableInfo = tables.get(memoryTableHandle.id());
InsertMode mode = tableInfo.truncated() ? InsertMode.OVERWRITE : InsertMode.APPEND;
tables.put(tableInfo.id(), new TableInfo(tableInfo.id(), tableInfo.schemaName(), tableInfo.tableName(), tableInfo.columns(), false, tableInfo.dataFragments(), tableInfo.comment()));
return new MemoryInsertTableHandle(memoryTableHandle.id(), mode, ImmutableSet.copyOf(tableIds.values()));
}

@Override
Expand All @@ -374,7 +379,7 @@ public synchronized void truncateTable(ConnectorSession session, ConnectorTableH
MemoryTableHandle handle = (MemoryTableHandle) tableHandle;
long tableId = handle.id();
TableInfo info = tables.get(handle.id());
tables.put(tableId, new TableInfo(tableId, info.schemaName(), info.tableName(), info.columns(), ImmutableMap.of(), info.comment()));
tables.put(tableId, new TableInfo(tableId, info.schemaName(), info.tableName(), info.columns(), true, ImmutableMap.of(), info.comment()));
}

@Override
Expand All @@ -393,7 +398,7 @@ public synchronized void addColumn(ConnectorSession session, ConnectorTableHandl
.add(new ColumnInfo(new MemoryColumnHandle(table.columns().size(), column.getType()), column.getName(), column.getType(), column.isNullable(), Optional.ofNullable(column.getComment())))
.build();

tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), columns, table.dataFragments(), table.comment()));
tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), columns, table.truncated(), table.dataFragments(), table.comment()));
}

@Override
Expand All @@ -408,7 +413,7 @@ public synchronized void renameColumn(ConnectorSession session, ConnectorTableHa
ColumnInfo columnInfo = columns.get(column.columnIndex());
columns.set(column.columnIndex(), new ColumnInfo(columnInfo.handle(), target, columnInfo.type(), columnInfo.nullable(), columnInfo.comment()));

tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), ImmutableList.copyOf(columns), table.dataFragments(), table.comment()));
tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), ImmutableList.copyOf(columns), table.truncated(), table.dataFragments(), table.comment()));
}

@Override
Expand All @@ -423,7 +428,7 @@ public synchronized void dropNotNullConstraint(ConnectorSession session, Connect
ColumnInfo columnInfo = columns.get(column.columnIndex());
columns.set(column.columnIndex(), new ColumnInfo(columnInfo.handle(), columnInfo.name(), columnInfo.type(), true, columnInfo.comment()));

tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), ImmutableList.copyOf(columns), table.dataFragments(), table.comment()));
tables.put(tableId, new TableInfo(tableId, table.schemaName(), table.tableName(), ImmutableList.copyOf(columns), table.truncated(), table.dataFragments(), table.comment()));
}

@Override
Expand Down Expand Up @@ -538,7 +543,7 @@ private void updateRowsOnHosts(long tableId, Collection<Slice> fragments)
dataFragments.merge(memoryDataFragment.hostAddress(), memoryDataFragment, MemoryDataFragment::merge);
}

tables.put(tableId, new TableInfo(tableId, info.schemaName(), info.tableName(), info.columns(), dataFragments, info.comment()));
tables.put(tableId, new TableInfo(tableId, info.schemaName(), info.tableName(), info.columns(), info.truncated(), dataFragments, info.comment()));
}

public synchronized List<MemoryDataFragment> getDataFragments(long tableId)
Expand Down Expand Up @@ -599,7 +604,7 @@ public synchronized void setTableComment(ConnectorSession session, ConnectorTabl
MemoryTableHandle table = (MemoryTableHandle) tableHandle;
TableInfo info = tables.get(table.id());
checkArgument(info != null, "Table not found");
tables.put(table.id(), new TableInfo(table.id(), info.schemaName(), info.tableName(), info.columns(), info.dataFragments(), comment));
tables.put(table.id(), new TableInfo(table.id(), info.schemaName(), info.tableName(), info.columns(), info.truncated(), info.dataFragments(), comment));
}

@Override
Expand All @@ -617,6 +622,7 @@ public synchronized void setColumnComment(ConnectorSession session, ConnectorTab
info.columns().stream()
.map(tableColumn -> Objects.equals(tableColumn.handle(), columnHandle) ? new ColumnInfo(tableColumn.handle(), tableColumn.name(), tableColumn.getMetadata().getType(), tableColumn.nullable(), comment) : tableColumn)
.collect(toImmutableList()),
info.truncated(),
info.dataFragments(),
info.comment()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.slice.Slice;
import io.trino.plugin.memory.MemoryInsertTableHandle.InsertMode;
import io.trino.spi.HostAddress;
import io.trino.spi.NodeManager;
import io.trino.spi.Page;
Expand Down Expand Up @@ -75,6 +76,10 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa
long tableId = memoryInsertTableHandle.table();
checkState(memoryInsertTableHandle.activeTableIds().contains(tableId));

if (memoryInsertTableHandle.mode() == InsertMode.OVERWRITE) {
pagesStore.purge(tableId);
}

pagesStore.cleanUp(memoryInsertTableHandle.activeTableIds());
pagesStore.initialize(tableId);
return new MemoryPageSink(pagesStore, currentHostAddress, tableId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ public synchronized boolean contains(Long tableId)
return tables.containsKey(tableId);
}

public synchronized void purge(long tableId)
{
tables.remove(tableId);
Comment thread
ebyhr marked this conversation as resolved.
Outdated
}

public synchronized void cleanUp(Set<Long> activeTableIds)
{
// We have to remember that there might be some race conditions when there are two tables created at once.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public record TableInfo(
String schemaName,
String tableName,
List<ColumnInfo> columns,
boolean truncated,
Map<HostAddress, MemoryDataFragment> dataFragments,
Optional<String> comment)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,19 @@ public void testRenameView()
assertUpdate("DROP SCHEMA test_different_schema");
}

@Test
void testInsertAfterTruncate()
{
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_truncate", "AS SELECT 1 x")) {
assertUpdate("TRUNCATE TABLE " + table.getName());
assertQueryReturnsEmptyResult("SELECT * FROM " + table.getName());

assertUpdate("INSERT INTO " + table.getName() + " VALUES 2", 1);
assertThat(query("SELECT * FROM " + table.getName()))
.matches("VALUES 2");
}
}

@Override
protected String errorMessageForInsertIntoNotNullColumn(String columnName)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.DataSize;
import io.trino.plugin.memory.MemoryInsertTableHandle.InsertMode;
import io.trino.spi.HostAddress;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
Expand Down Expand Up @@ -165,7 +166,7 @@ private static ConnectorOutputTableHandle createMemoryOutputTableHandle(long tab

private static ConnectorInsertTableHandle createMemoryInsertTableHandle(long tableId, Long[] activeTableIds)
{
return new MemoryInsertTableHandle(tableId, ImmutableSet.copyOf(activeTableIds));
return new MemoryInsertTableHandle(tableId, InsertMode.APPEND, ImmutableSet.copyOf(activeTableIds));
}

private static Page createPage()
Expand Down