diff --git a/enginetest/queries/load_queries.go b/enginetest/queries/load_queries.go index 1f29fadf20..de29acace3 100644 --- a/enginetest/queries/load_queries.go +++ b/enginetest/queries/load_queries.go @@ -43,8 +43,7 @@ var LoadDataScripts = []ScriptTest{ }, Assertions: []ScriptTestAssertion{ { - Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'", - + Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'", ExpectedErrStr: "Check constraint \"loadtable_chk_1\" violated", }, }, @@ -275,11 +274,19 @@ var LoadDataErrorScripts = []ScriptTest{ { Name: "Load data with unknown columns throws an error", SetUpScript: []string{ - "create table loadtable(pk int primary key)", + "create table loadtable(pk int primary key, i int)", }, Assertions: []ScriptTestAssertion{ { - Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (bad)", + Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (fake_col, pk, i)", + ExpectedErr: plan.ErrInsertIntoNonexistentColumn, + }, + { + Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, fake_col, i)", + ExpectedErr: plan.ErrInsertIntoNonexistentColumn, + }, + { + Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, i, fake_col)", ExpectedErr: plan.ErrInsertIntoNonexistentColumn, }, }, diff --git a/sql/analyzer/inserts.go b/sql/analyzer/inserts.go index a94c8557cc..39ac7d488a 100644 --- a/sql/analyzer/inserts.go +++ b/sql/analyzer/inserts.go @@ -45,22 +45,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc return nil, transform.SameTree, err } - if insert.IsReplace { - var ok bool - _, ok = insertable.(sql.ReplaceableTable) - if !ok { - return nil, transform.SameTree, plan.ErrReplaceIntoNotSupported.New() - } - } - - if len(insert.OnDupExprs) > 0 { - var ok bool - 
_, ok = insertable.(sql.UpdatableTable) - if !ok { - return nil, transform.SameTree, plan.ErrOnDuplicateKeyUpdateNotSupported.New() - } - } - source := insert.Source // TriggerExecutor has already been analyzed if _, ok := insert.Source.(*plan.TriggerExecutor); !ok { @@ -93,16 +77,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc for i, f := range dstSchema { columnNames[i] = f.Name } - } else { - err = validateColumns(table.Name(), columnNames, dstSchema, source) - if err != nil { - return nil, transform.SameTree, err - } - } - - err = validateValueCount(columnNames, source) - if err != nil { - return nil, transform.SameTree, err } // The schema of the destination node and the underlying table differ subtly in terms of defaults @@ -201,29 +175,6 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil } -func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error { - dstColNames := make(map[string]*sql.Column) - for _, dstCol := range dstSchema { - dstColNames[strings.ToLower(dstCol.Name)] = dstCol - } - usedNames := make(map[string]struct{}) - for i, columnName := range columnNames { - dstCol, exists := dstColNames[columnName] - if !exists { - return plan.ErrInsertIntoNonexistentColumn.New(columnName) - } - if dstCol.Generated != nil && !validGeneratedColumnValue(i, source) { - return sql.ErrGeneratedColumnValue.New(dstCol.Name, tableName) - } - if _, exists := usedNames[columnName]; !exists { - usedNames[columnName] = struct{}{} - } else { - return plan.ErrInsertIntoDuplicateColumn.New(columnName) - } - } - return nil -} - // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node. 
// Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column func validGeneratedColumnValue(idx int, source sql.Node) bool { @@ -248,35 +199,6 @@ func validGeneratedColumnValue(idx int, source sql.Node) bool { } } -func validateValueCount(columnNames []string, values sql.Node) error { - if exchange, ok := values.(*plan.Exchange); ok { - values = exchange.Child - } - - switch node := values.(type) { - case *plan.Values: - for _, exprTuple := range node.ExpressionTuples { - if len(exprTuple) != len(columnNames) { - return sql.ErrInsertIntoMismatchValueCount.New() - } - } - case *plan.LoadData: - dataColLen := len(node.ColumnNames) - if dataColLen == 0 { - dataColLen = len(node.Schema()) - } - if len(columnNames) != dataColLen { - return sql.ErrInsertIntoMismatchValueCount.New() - } - default: - // Parser assures us that this will be some form of SelectStatement, so no need to type check it - if len(columnNames) != len(values.Schema()) { - return sql.ErrInsertIntoMismatchValueCount.New() - } - } - return nil -} - func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error { for _, expr := range projExprs { switch e := expr.(type) { diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 69f5c0f795..9cae5f4475 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -107,6 +107,9 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { dest := destScope.node ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore) + + b.validateInsert(ins) + outScope = destScope outScope.node = ins if rt != nil { diff --git a/sql/planbuilder/dml_validate.go b/sql/planbuilder/dml_validate.go new file mode 100644 index 0000000000..087a29c705 --- /dev/null +++ b/sql/planbuilder/dml_validate.go @@ -0,0 +1,171 @@ +// Copyright 2024 Dolthub, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package planbuilder + +import ( + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" +) + +func (b *Builder) validateInsert(ins *plan.InsertInto) { + table := getResolvedTable(ins.Destination) + if table == nil { + return + } + + insertable, err := plan.GetInsertable(table) + if err != nil { + b.handleErr(err) + } + + if ins.IsReplace { + var ok bool + _, ok = insertable.(sql.ReplaceableTable) + if !ok { + err := plan.ErrReplaceIntoNotSupported.New() + b.handleErr(err) + } + } + + if len(ins.OnDupExprs) > 0 { + var ok bool + _, ok = insertable.(sql.UpdatableTable) + if !ok { + err := plan.ErrOnDuplicateKeyUpdateNotSupported.New() + b.handleErr(err) + } + } + + // Lower-case the requested column names; validateColumns looks them up in a map keyed by lower-cased schema column names + dstSchema := insertable.Schema() + columnNames := make([]string, len(ins.ColumnNames)) + for i, name := range ins.ColumnNames { + columnNames[i] = strings.ToLower(name) + } + + // If no columns are given and value tuples are not all empty, use the full schema (names are copied verbatim and only feed the value-count check below, since validateColumns is skipped in that case) + if len(columnNames) == 0 && existsNonZeroValueCount(ins.Source) { + columnNames = make([]string, len(dstSchema)) + for i, f := range dstSchema { + columnNames[i] = f.Name + } + } + + if len(ins.ColumnNames) > 0 { + err := validateColumns(table.Name(), columnNames, dstSchema, ins.Source) + if err != nil { + b.handleErr(err) + } + } + + err = 
validateValueCount(columnNames, ins.Source) + if err != nil { + b.handleErr(err) + } +} + +// existsNonZeroValueCount reports whether values has at least one non-empty tuple: a *plan.Values node yields true only if some tuple contains an expression, and every other node type yields true +func existsNonZeroValueCount(values sql.Node) bool { + switch node := values.(type) { + case *plan.Values: + for _, exprTuple := range node.ExpressionTuples { + if len(exprTuple) != 0 { + return true + } + } + default: + return true + } + return false +} + +func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error { + dstColNames := make(map[string]*sql.Column) + for _, dstCol := range dstSchema { + dstColNames[strings.ToLower(dstCol.Name)] = dstCol + } + usedNames := make(map[string]struct{}) + for i, columnName := range columnNames { + dstCol, exists := dstColNames[columnName] + if !exists { + return plan.ErrInsertIntoNonexistentColumn.New(columnName) + } + if dstCol.Generated != nil && !validGeneratedColumnValue(i, source) { + return sql.ErrGeneratedColumnValue.New(dstCol.Name, tableName) + } + if _, exists := usedNames[columnName]; !exists { + usedNames[columnName] = struct{}{} + } else { + return plan.ErrInsertIntoDuplicateColumn.New(columnName) + } + } + return nil +} + +// validGeneratedColumnValue reports whether the value supplied at position idx is an explicit DEFAULT. It returns true only when source is a *plan.Values node whose first tuple holds a *sql.ColumnDefaultValue (bare or inside an *expression.Wrapper) at idx; any other source node yields false. 
+// Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column. NOTE(review): every branch returns during the first loop iteration, so later tuples are never inspected, and tuple[idx] is indexed without a length check (validateInsert calls validateColumns before validateValueCount); confirm both are intended. +func validGeneratedColumnValue(idx int, source sql.Node) bool { + switch source := source.(type) { + case *plan.Values: + for _, tuple := range source.ExpressionTuples { + switch val := tuple[idx].(type) { + case *sql.ColumnDefaultValue: // should be wrapped, but just in case + return true + case *expression.Wrapper: + if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok { + return true + } + return false + default: + return false + } + } + return false + default: + return false + } +} + +func validateValueCount(columnNames []string, values sql.Node) error { + if exchange, ok := values.(*plan.Exchange); ok { + values = exchange.Child + } + + switch node := values.(type) { + case *plan.Values: + for _, exprTuple := range node.ExpressionTuples { + if len(exprTuple) != len(columnNames) { + return sql.ErrInsertIntoMismatchValueCount.New() + } + } + case *plan.LoadData: + dataColLen := len(node.ColumnNames) + if dataColLen == 0 { + dataColLen = len(node.Schema()) + } + if len(columnNames) != dataColLen { + return sql.ErrInsertIntoMismatchValueCount.New() + } + default: + // Parser assures us that this will be some form of SelectStatement, so no need to type check it + if len(columnNames) != len(values.Schema()) { + return sql.ErrInsertIntoMismatchValueCount.New() + } + } + return nil +} diff --git a/sql/planbuilder/load.go b/sql/planbuilder/load.go index 2dffaccb03..411a2d0244 100644 --- a/sql/planbuilder/load.go +++ b/sql/planbuilder/load.go @@ -69,7 +69,7 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) { ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, columnsToStrings(d.Columns), d.Fields, d.Lines, ignoreNumVal, d.IgnoreOrReplace) outScope = inScope.push() ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), ld, ld.IsReplace, ld.ColumnNames, nil, ld.IsIgnore) - + b.validateInsert(ins) outScope.node = ins if rt != nil 
{ checks := b.loadChecksFromTable(destScope, rt.Table)