Skip to content

Commit 078b680

Browse files
David Strykerelectrum
authored andcommitted
Support enforcement of NOT NULL column declarations
This commit enforces NOT NULL column declarations on write in the Presto engine, so it applies to all connectors. The existing Postgres and Mysql tests named testInsertIntoNotNullColumn were changed to check for the new error message, and a new test with the same name was added to TestIcebergSmoke. One possible concern with this commit is that the error message issued by the Presto engine when writing a null to a NOT NULL column is a different message than the Connector might issue if no value was supplied for the NOT NULL column. I think this is ok, because the error messages supplied by the Connectors are completely specific to the Connector.
1 parent a464029 commit 078b680

File tree

13 files changed

+91
-12
lines changed

13 files changed

+91
-12
lines changed

presto-iceberg/src/main/java/io/prestosql/plugin/iceberg/IcebergMetadata.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ private List<ColumnMetadata> getColumnMetadatas(org.apache.iceberg.Table table)
543543
return ColumnMetadata.builder()
544544
.setName(column.name())
545545
.setType(toPrestoType(column.type(), typeManager))
546+
.setNullable(column.isOptional())
546547
.setComment(Optional.ofNullable(column.doc()))
547548
.build();
548549
})

presto-iceberg/src/test/java/io/prestosql/plugin/iceberg/TestIcebergSmoke.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,22 @@ public void testSchemaEvolution()
391391
testWithAllFileFormats(this::testSchemaEvolution);
392392
}
393393

394+
@Test
395+
public void testInsertIntoNotNullColumn()
396+
{
397+
assertUpdate("CREATE TABLE test_not_null_table (c1 INTEGER, c2 INTEGER NOT NULL)");
398+
assertUpdate("INSERT INTO test_not_null_table (c2) VALUES (2)", 1);
399+
assertQuery("SELECT * FROM test_not_null_table", "VALUES (NULL, 2)");
400+
assertQueryFails("INSERT INTO test_not_null_table (c1) VALUES (1)", "NULL value not allowed for NOT NULL column: c2");
401+
assertUpdate("DROP TABLE IF EXISTS test_not_null_table");
402+
403+
assertUpdate("CREATE TABLE test_commuted_not_null_table (a BIGINT, b BIGINT NOT NULL)");
404+
assertUpdate("INSERT INTO test_commuted_not_null_table (b) VALUES (2)", 1);
405+
assertQuery("SELECT * FROM test_commuted_not_null_table", "VALUES (NULL, 2)");
406+
assertQueryFails("INSERT INTO test_commuted_not_null_table (b, a) VALUES (NULL, 3)", "NULL value not allowed for NOT NULL column: b");
407+
assertUpdate("DROP TABLE IF EXISTS test_commuted_not_null_table");
408+
}
409+
394410
private void testSchemaEvolution(Session session, FileFormat fileFormat)
395411
{
396412
assertUpdate(session, "CREATE TABLE test_schema_evolution_drop_end (col0 INTEGER, col1 INTEGER, col2 INTEGER) WITH (format = '" + fileFormat + "')");

presto-main/src/main/java/io/prestosql/operator/TableWriterOperator.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.prestosql.operator.OperationTimer.OperationTiming;
2626
import io.prestosql.spi.Page;
2727
import io.prestosql.spi.PageBuilder;
28+
import io.prestosql.spi.PrestoException;
2829
import io.prestosql.spi.block.Block;
2930
import io.prestosql.spi.block.BlockBuilder;
3031
import io.prestosql.spi.block.RunLengthEncodedBlock;
@@ -48,6 +49,7 @@
4849
import static io.airlift.concurrent.MoreFutures.getFutureValue;
4950
import static io.airlift.concurrent.MoreFutures.toListenableFuture;
5051
import static io.prestosql.SystemSessionProperties.isStatisticsCpuTimerEnabled;
52+
import static io.prestosql.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
5153
import static io.prestosql.spi.type.BigintType.BIGINT;
5254
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
5355
import static io.prestosql.sql.planner.plan.TableWriterNode.CreateTarget;
@@ -70,6 +72,7 @@ public static class TableWriterOperatorFactory
7072
private final PageSinkManager pageSinkManager;
7173
private final WriterTarget target;
7274
private final List<Integer> columnChannels;
75+
private final List<String> notNullChannelColumnNames;
7376
private final Session session;
7477
private final OperatorFactory statisticsAggregationOperatorFactory;
7578
private final List<Type> types;
@@ -81,13 +84,15 @@ public TableWriterOperatorFactory(
8184
PageSinkManager pageSinkManager,
8285
WriterTarget writerTarget,
8386
List<Integer> columnChannels,
87+
List<String> notNullChannelColumnNames,
8488
Session session,
8589
OperatorFactory statisticsAggregationOperatorFactory,
8690
List<Type> types)
8791
{
8892
this.operatorId = operatorId;
8993
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
9094
this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
95+
this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
9196
this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
9297
checkArgument(writerTarget instanceof CreateTarget || writerTarget instanceof InsertTarget, "writerTarget must be CreateTarget or InsertTarget");
9398
this.target = requireNonNull(writerTarget, "writerTarget is null");
@@ -103,7 +108,7 @@ public Operator createOperator(DriverContext driverContext)
103108
OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, TableWriterOperator.class.getSimpleName());
104109
Operator statisticsAggregationOperator = statisticsAggregationOperatorFactory.createOperator(driverContext);
105110
boolean statisticsCpuTimerEnabled = !(statisticsAggregationOperator instanceof DevNullOperator) && isStatisticsCpuTimerEnabled(session);
106-
return new TableWriterOperator(context, createPageSink(), columnChannels, statisticsAggregationOperator, types, statisticsCpuTimerEnabled);
111+
return new TableWriterOperator(context, createPageSink(), columnChannels, notNullChannelColumnNames, statisticsAggregationOperator, types, statisticsCpuTimerEnabled);
107112
}
108113

109114
private ConnectorPageSink createPageSink()
@@ -126,7 +131,7 @@ public void noMoreOperators()
126131
@Override
127132
public OperatorFactory duplicate()
128133
{
129-
return new TableWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, columnChannels, session, statisticsAggregationOperatorFactory, types);
134+
return new TableWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, columnChannels, notNullChannelColumnNames, session, statisticsAggregationOperatorFactory, types);
130135
}
131136
}
132137

@@ -139,6 +144,7 @@ private enum State
139144
private final LocalMemoryContext pageSinkMemoryContext;
140145
private final ConnectorPageSink pageSink;
141146
private final List<Integer> columnChannels;
147+
private final List<String> notNullChannelColumnNames;
142148
private final AtomicLong pageSinkPeakMemoryUsage = new AtomicLong();
143149
private final Operator statisticAggregationOperator;
144150
private final List<Type> types;
@@ -158,6 +164,7 @@ public TableWriterOperator(
158164
OperatorContext operatorContext,
159165
ConnectorPageSink pageSink,
160166
List<Integer> columnChannels,
167+
List<String> notNullChannelColumnNames,
161168
Operator statisticAggregationOperator,
162169
List<Type> types,
163170
boolean statisticsCpuTimerEnabled)
@@ -166,6 +173,8 @@ public TableWriterOperator(
166173
this.pageSinkMemoryContext = operatorContext.newLocalSystemMemoryContext(TableWriterOperator.class.getSimpleName());
167174
this.pageSink = requireNonNull(pageSink, "pageSink is null");
168175
this.columnChannels = requireNonNull(columnChannels, "columnChannels is null");
176+
this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null");
177+
checkArgument(columnChannels.size() == notNullChannelColumnNames.size(), "columnChannels and notNullColumnNames have different sizes");
169178
this.operatorContext.setInfoSupplier(this::getInfo);
170179
this.statisticAggregationOperator = requireNonNull(statisticAggregationOperator, "statisticAggregationOperator is null");
171180
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
@@ -227,7 +236,12 @@ public void addInput(Page page)
227236

228237
Block[] blocks = new Block[columnChannels.size()];
229238
for (int outputChannel = 0; outputChannel < columnChannels.size(); outputChannel++) {
230-
blocks[outputChannel] = page.getBlock(columnChannels.get(outputChannel));
239+
Block block = page.getBlock(columnChannels.get(outputChannel));
240+
String columnName = notNullChannelColumnNames.get(outputChannel);
241+
if (columnName != null) {
242+
verifyBlockHasNoNulls(block, columnName);
243+
}
244+
blocks[outputChannel] = block;
231245
}
232246

233247
OperationTimer timer = new OperationTimer(statisticsCpuTimerEnabled);
@@ -243,6 +257,18 @@ public void addInput(Page page)
243257
updateWrittenBytes();
244258
}
245259

260+
private void verifyBlockHasNoNulls(Block block, String columnName)
261+
{
262+
if (!block.mayHaveNull()) {
263+
return;
264+
}
265+
for (int position = 0; position < block.getPositionCount(); position++) {
266+
if (block.isNull(position)) {
267+
throw new PrestoException(CONSTRAINT_VIOLATION, "NULL value not allowed for NOT NULL column: " + columnName);
268+
}
269+
}
270+
}
271+
246272
@Override
247273
public Page getOutput()
248274
{

presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2335,12 +2335,17 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl
23352335
.map(source::symbolToChannel)
23362336
.collect(toImmutableList());
23372337

2338+
List<String> notNullChannelColumnNames = node.getColumns().stream()
2339+
.map(symbol -> node.getNotNullColumnSymbols().contains(symbol) ? node.getColumnNames().get(source.symbolToChannel(symbol)) : null)
2340+
.collect(Collectors.toList());
2341+
23382342
OperatorFactory operatorFactory = new TableWriterOperatorFactory(
23392343
context.getNextOperatorId(),
23402344
node.getId(),
23412345
pageSinkManager,
23422346
node.getTarget(),
23432347
inputChannels,
2348+
notNullChannelColumnNames,
23442349
session,
23452350
statisticsAggregation,
23462351
getSymbolTypes(node.getOutputSymbols(), context.getTypes()));

presto-main/src/main/java/io/prestosql/sql/planner/LogicalPlanner.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,13 @@
9292
import java.util.Map.Entry;
9393
import java.util.Objects;
9494
import java.util.Optional;
95+
import java.util.Set;
9596

9697
import static com.google.common.base.Preconditions.checkState;
9798
import static com.google.common.base.Verify.verify;
9899
import static com.google.common.collect.ImmutableList.toImmutableList;
99100
import static com.google.common.collect.ImmutableMap.toImmutableMap;
101+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
100102
import static com.google.common.collect.Streams.zip;
101103
import static io.prestosql.SystemSessionProperties.isCollectPlanStatisticsForAllQueries;
102104
import static io.prestosql.SystemSessionProperties.isUsePreferredWritePartitioning;
@@ -334,6 +336,7 @@ private RelationPlan createTableCreationPlan(Analysis analysis, Query query)
334336
plan,
335337
new CreateReference(destination.getCatalogName(), tableMetadata, newTableLayout),
336338
columnNames,
339+
tableMetadata.getColumns(),
337340
newTableLayout,
338341
statisticsMetadata);
339342
}
@@ -407,6 +410,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
407410
.map(columns::get)
408411
.collect(toImmutableList())),
409412
insertedTableColumnNames,
413+
insertedColumns,
410414
insert.getNewTableLayout(),
411415
statisticsMetadata);
412416
}
@@ -416,6 +420,7 @@ private RelationPlan createTableWriterPlan(
416420
RelationPlan plan,
417421
WriterTarget target,
418422
List<String> columnNames,
423+
List<ColumnMetadata> columnMetadataList,
419424
Optional<NewTableLayout> writeTableLayout,
420425
TableStatisticsMetadata statisticsMetadata)
421426
{
@@ -448,11 +453,17 @@ else if (isUsePreferredWritePartitioning(session)) {
448453
}
449454
}
450455

451-
if (!statisticsMetadata.isEmpty()) {
452-
verify(columnNames.size() == symbols.size(), "columnNames.size() != symbols.size(): %s and %s", columnNames, symbols);
453-
Map<String, Symbol> columnToSymbolMap = zip(columnNames.stream(), symbols.stream(), SimpleImmutableEntry::new)
454-
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
456+
verify(columnNames.size() == symbols.size(), "columnNames.size() != symbols.size(): %s and %s", columnNames, symbols);
457+
Map<String, Symbol> columnToSymbolMap = zip(columnNames.stream(), symbols.stream(), SimpleImmutableEntry::new)
458+
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
459+
460+
Set<Symbol> notNullColumnSymbols = columnMetadataList.stream()
461+
.filter(column -> !column.isNullable())
462+
.map(ColumnMetadata::getName)
463+
.map(columnToSymbolMap::get)
464+
.collect(toImmutableSet());
455465

466+
if (!statisticsMetadata.isEmpty()) {
456467
TableStatisticAggregation result = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnToSymbolMap);
457468

458469
StatisticAggregations.Parts aggregations = result.getAggregations().createPartialAggregations(symbolAllocator, metadata);
@@ -473,6 +484,7 @@ else if (isUsePreferredWritePartitioning(session)) {
473484
symbolAllocator.newSymbol("fragment", VARBINARY),
474485
symbols,
475486
columnNames,
487+
notNullColumnSymbols,
476488
partitioningScheme,
477489
Optional.of(partialAggregation),
478490
Optional.of(result.getDescriptor().map(aggregations.getMappings()::get))),
@@ -494,6 +506,7 @@ else if (isUsePreferredWritePartitioning(session)) {
494506
symbolAllocator.newSymbol("fragment", VARBINARY),
495507
symbols,
496508
columnNames,
509+
notNullColumnSymbols,
497510
partitioningScheme,
498511
Optional.empty(),
499512
Optional.empty()),

presto-main/src/main/java/io/prestosql/sql/planner/optimizations/BeginTableWrite.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Context> c
9898
node.getFragmentSymbol(),
9999
node.getColumns(),
100100
node.getColumnNames(),
101+
node.getNotNullColumnSymbols(),
101102
node.getPartitioningScheme(),
102103
node.getStatisticsAggregation(),
103104
node.getStatisticsAggregationDescriptor());

presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Set<Symbol
681681
node.getFragmentSymbol(),
682682
node.getColumns(),
683683
node.getColumnNames(),
684+
node.getNotNullColumnSymbols(),
684685
node.getPartitioningScheme(),
685686
node.getStatisticsAggregation(),
686687
node.getStatisticsAggregationDescriptor());

presto-main/src/main/java/io/prestosql/sql/planner/optimizations/SymbolMapper.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new
162162
map(node.getFragmentSymbol()),
163163
columns,
164164
node.getColumnNames(),
165+
node.getNotNullColumnSymbols(),
165166
node.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)),
166167
node.getStatisticsAggregation().map(this::map),
167168
node.getStatisticsAggregationDescriptor().map(this::map));

presto-main/src/main/java/io/prestosql/sql/planner/plan/TableWriterNode.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.fasterxml.jackson.annotation.JsonSubTypes;
1919
import com.fasterxml.jackson.annotation.JsonTypeInfo;
2020
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableSet;
2122
import com.google.common.collect.Iterables;
2223
import io.prestosql.metadata.InsertTableHandle;
2324
import io.prestosql.metadata.NewTableLayout;
@@ -33,6 +34,7 @@
3334

3435
import java.util.List;
3536
import java.util.Optional;
37+
import java.util.Set;
3638

3739
import static com.google.common.base.Preconditions.checkArgument;
3840
import static java.util.Objects.requireNonNull;
@@ -47,6 +49,7 @@ public class TableWriterNode
4749
private final Symbol fragmentSymbol;
4850
private final List<Symbol> columns;
4951
private final List<String> columnNames;
52+
private final Set<Symbol> notNullColumnSymbols;
5053
private final Optional<PartitioningScheme> partitioningScheme;
5154
private final Optional<StatisticAggregations> statisticsAggregation;
5255
private final Optional<StatisticAggregationsDescriptor<Symbol>> statisticsAggregationDescriptor;
@@ -61,6 +64,7 @@ public TableWriterNode(
6164
@JsonProperty("fragmentSymbol") Symbol fragmentSymbol,
6265
@JsonProperty("columns") List<Symbol> columns,
6366
@JsonProperty("columnNames") List<String> columnNames,
67+
@JsonProperty("notNullColumnSymbols") Set<Symbol> notNullColumnSymbols,
6468
@JsonProperty("partitioningScheme") Optional<PartitioningScheme> partitioningScheme,
6569
@JsonProperty("statisticsAggregation") Optional<StatisticAggregations> statisticsAggregation,
6670
@JsonProperty("statisticsAggregationDescriptor") Optional<StatisticAggregationsDescriptor<Symbol>> statisticsAggregationDescriptor)
@@ -77,6 +81,7 @@ public TableWriterNode(
7781
this.fragmentSymbol = requireNonNull(fragmentSymbol, "fragmentSymbol is null");
7882
this.columns = ImmutableList.copyOf(columns);
7983
this.columnNames = ImmutableList.copyOf(columnNames);
84+
this.notNullColumnSymbols = ImmutableSet.copyOf(requireNonNull(notNullColumnSymbols, "notNullColumns is null"));
8085
this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null");
8186
this.statisticsAggregation = requireNonNull(statisticsAggregation, "statisticsAggregation is null");
8287
this.statisticsAggregationDescriptor = requireNonNull(statisticsAggregationDescriptor, "statisticsAggregationDescriptor is null");
@@ -128,6 +133,12 @@ public List<String> getColumnNames()
128133
return columnNames;
129134
}
130135

136+
@JsonProperty
137+
public Set<Symbol> getNotNullColumnSymbols()
138+
{
139+
return notNullColumnSymbols;
140+
}
141+
131142
@JsonProperty
132143
public Optional<PartitioningScheme> getPartitioningScheme()
133144
{
@@ -167,7 +178,7 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context)
167178
@Override
168179
public PlanNode replaceChildren(List<PlanNode> newChildren)
169180
{
170-
return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountSymbol, fragmentSymbol, columns, columnNames, partitioningScheme, statisticsAggregation, statisticsAggregationDescriptor);
181+
return new TableWriterNode(getId(), Iterables.getOnlyElement(newChildren), target, rowCountSymbol, fragmentSymbol, columns, columnNames, notNullColumnSymbols, partitioningScheme, statisticsAggregation, statisticsAggregationDescriptor);
171182
}
172183

173184
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "@type")

presto-main/src/test/java/io/prestosql/operator/TestTableWriterOperator.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ private Operator createTableWriterOperator(
292292
Session session,
293293
DriverContext driverContext)
294294
{
295+
List<String> notNullColumnNames = new ArrayList<>(1);
296+
notNullColumnNames.add(null);
295297
TableWriterOperatorFactory factory = new TableWriterOperatorFactory(
296298
0,
297299
new PlanNodeId("test"),
@@ -302,6 +304,7 @@ private Operator createTableWriterOperator(
302304
new ConnectorOutputTableHandle() {}),
303305
new SchemaTableName("testSchema", "testTable")),
304306
ImmutableList.of(0),
307+
notNullColumnNames,
305308
session,
306309
statisticsAggregation,
307310
outputTypes);

0 commit comments

Comments
 (0)