Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions enginetest/queries/load_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ var LoadDataScripts = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'",

Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'",
ExpectedErrStr: "Check constraint \"loadtable_chk_1\" violated",
},
},
Expand Down Expand Up @@ -275,11 +274,19 @@ var LoadDataErrorScripts = []ScriptTest{
{
Name: "Load data with unknown columns throws an error",
SetUpScript: []string{
"create table loadtable(pk int primary key)",
"create table loadtable(pk int primary key, i int)",
},
Assertions: []ScriptTestAssertion{
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (bad)",
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (fake_col, pk, i)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, fake_col, i)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, i, fake_col)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
},
Expand Down
78 changes: 0 additions & 78 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
return nil, transform.SameTree, err
}

if insert.IsReplace {
var ok bool
_, ok = insertable.(sql.ReplaceableTable)
if !ok {
return nil, transform.SameTree, plan.ErrReplaceIntoNotSupported.New()
}
}

if len(insert.OnDupExprs) > 0 {
var ok bool
_, ok = insertable.(sql.UpdatableTable)
if !ok {
return nil, transform.SameTree, plan.ErrOnDuplicateKeyUpdateNotSupported.New()
}
}

source := insert.Source
// TriggerExecutor has already been analyzed
if _, ok := insert.Source.(*plan.TriggerExecutor); !ok {
Expand Down Expand Up @@ -93,16 +77,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
for i, f := range dstSchema {
columnNames[i] = f.Name
}
} else {
err = validateColumns(table.Name(), columnNames, dstSchema, source)
if err != nil {
return nil, transform.SameTree, err
}
}

err = validateValueCount(columnNames, source)
if err != nil {
return nil, transform.SameTree, err
}

// The schema of the destination node and the underlying table differ subtly in terms of defaults
Expand Down Expand Up @@ -201,29 +175,6 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil
}

// validateColumns checks an explicit insert column list against the destination schema.
//
// Each entry of columnNames is looked up, as given, in a map keyed by the lower-cased
// schema column names (callers presumably pass already lower-cased names — confirm at
// the call site). For every listed column, in list order, the first failing check wins:
//   - the column is not in dstSchema: plan.ErrInsertIntoNonexistentColumn
//   - the column is generated and validGeneratedColumnValue rejects the value at the
//     same position in source: sql.ErrGeneratedColumnValue
//   - the column was already listed earlier: plan.ErrInsertIntoDuplicateColumn
//
// It returns nil when every listed column passes.
func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error {
	schemaCols := make(map[string]*sql.Column, len(dstSchema))
	for _, col := range dstSchema {
		schemaCols[strings.ToLower(col.Name)] = col
	}

	seen := make(map[string]bool, len(columnNames))
	for pos, name := range columnNames {
		col, known := schemaCols[name]
		switch {
		case !known:
			return plan.ErrInsertIntoNonexistentColumn.New(name)
		case col.Generated != nil && !validGeneratedColumnValue(pos, source):
			return sql.ErrGeneratedColumnValue.New(col.Name, tableName)
		case seen[name]:
			return plan.ErrInsertIntoDuplicateColumn.New(name)
		}
		seen[name] = true
	}
	return nil
}

// validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node.
// Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column
func validGeneratedColumnValue(idx int, source sql.Node) bool {
Expand All @@ -248,35 +199,6 @@ func validGeneratedColumnValue(idx int, source sql.Node) bool {
}
}

// validateValueCount verifies that the insert source supplies exactly len(columnNames)
// values per row, returning sql.ErrInsertIntoMismatchValueCount otherwise.
//
// A *plan.Exchange source is unwrapped to its child first. Then:
//   - *plan.Values: every expression tuple must have len(columnNames) entries.
//   - *plan.LoadData: its explicit column list is counted, or its schema when that
//     list is empty.
//   - anything else: the width of the node's schema is counted.
func validateValueCount(columnNames []string, values sql.Node) error {
	src := values
	if ex, isExchange := src.(*plan.Exchange); isExchange {
		src = ex.Child
	}
	want := len(columnNames)

	if vals, isValues := src.(*plan.Values); isValues {
		for _, tuple := range vals.ExpressionTuples {
			if len(tuple) != want {
				return sql.ErrInsertIntoMismatchValueCount.New()
			}
		}
		return nil
	}

	var got int
	if load, isLoad := src.(*plan.LoadData); isLoad {
		got = len(load.ColumnNames)
		if got == 0 {
			// No explicit column list on the LOAD DATA node: fall back to its schema width.
			got = len(load.Schema())
		}
	} else {
		// Parser assures us that this will be some form of SelectStatement, so no need to type check it
		got = len(src.Schema())
	}

	if got != want {
		return sql.ErrInsertIntoMismatchValueCount.New()
	}
	return nil
}

func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error {
for _, expr := range projExprs {
switch e := expr.(type) {
Expand Down
3 changes: 3 additions & 0 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
dest := destScope.node

ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore)

b.validateInsert(ins)

outScope = destScope
outScope.node = ins
if rt != nil {
Expand Down
171 changes: 171 additions & 0 deletions sql/planbuilder/dml_validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package planbuilder

import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
)

// validateInsert runs build-time checks on an INSERT / REPLACE / LOAD DATA plan node.
// Every failure is reported through b.handleErr (which presumably aborts the build —
// confirm against its definition). If the destination does not resolve to a table the
// function returns without checking anything.
//
// Checks, in order:
//  1. the destination is insertable (plan.GetInsertable);
//  2. REPLACE requires the table to implement sql.ReplaceableTable;
//  3. ON DUPLICATE KEY UPDATE requires the table to implement sql.UpdatableTable;
//  4. an explicit column list is validated with validateColumns;
//  5. the number of supplied values is validated with validateValueCount.
func (b *Builder) validateInsert(ins *plan.InsertInto) {
	table := getResolvedTable(ins.Destination)
	if table == nil {
		return
	}

	insertable, err := plan.GetInsertable(table)
	if err != nil {
		b.handleErr(err)
	}

	if ins.IsReplace {
		if _, canReplace := insertable.(sql.ReplaceableTable); !canReplace {
			b.handleErr(plan.ErrReplaceIntoNotSupported.New())
		}
	}

	if len(ins.OnDupExprs) > 0 {
		if _, canUpdate := insertable.(sql.UpdatableTable); !canUpdate {
			b.handleErr(plan.ErrOnDuplicateKeyUpdateNotSupported.New())
		}
	}

	dstSchema := insertable.Schema()
	hasExplicitColumns := len(ins.ColumnNames) > 0

	// Work out the effective column list that the value counts are compared against.
	var columnNames []string
	switch {
	case hasExplicitColumns:
		// Normalize the user-supplied names to lower case.
		columnNames = make([]string, 0, len(ins.ColumnNames))
		for _, name := range ins.ColumnNames {
			columnNames = append(columnNames, strings.ToLower(name))
		}
	case existsNonZeroValueCount(ins.Source):
		// No columns given, but values are supplied: the full schema is the target.
		columnNames = make([]string, 0, len(dstSchema))
		for _, col := range dstSchema {
			columnNames = append(columnNames, col.Name)
		}
	}

	if hasExplicitColumns {
		if colErr := validateColumns(table.Name(), columnNames, dstSchema, ins.Source); colErr != nil {
			b.handleErr(colErr)
		}
	}

	if countErr := validateValueCount(columnNames, ins.Source); countErr != nil {
		b.handleErr(countErr)
	}
}

// existsNonZeroValueCount reports whether the insert source supplies at least one value.
// For a *plan.Values source it returns true as soon as any expression tuple is non-empty,
// and false when every tuple is empty or there are no tuples. Any other source type is
// treated as supplying values and yields true.
func existsNonZeroValueCount(values sql.Node) bool {
	vals, isValues := values.(*plan.Values)
	if !isValues {
		return true
	}
	for _, tuple := range vals.ExpressionTuples {
		if len(tuple) > 0 {
			return true
		}
	}
	return false
}

// validateColumns validates an explicit insert column list against dstSchema.
//
// Names in columnNames are matched, unchanged, against the lower-cased names of the
// schema columns. Walking the list in order, it returns:
//   - plan.ErrInsertIntoNonexistentColumn for a name with no matching schema column;
//   - sql.ErrGeneratedColumnValue when the matched column is generated and
//     validGeneratedColumnValue rejects the value at that position in source;
//   - plan.ErrInsertIntoDuplicateColumn when the name appeared earlier in the list.
//
// A nil result means every listed column passed all three checks.
func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error {
	byLowerName := make(map[string]*sql.Column, len(dstSchema))
	for i := range dstSchema {
		byLowerName[strings.ToLower(dstSchema[i].Name)] = dstSchema[i]
	}

	alreadyListed := make(map[string]struct{}, len(columnNames))
	for idx, name := range columnNames {
		target, found := byLowerName[name]
		if !found {
			return plan.ErrInsertIntoNonexistentColumn.New(name)
		}

		if target.Generated != nil && !validGeneratedColumnValue(idx, source) {
			return sql.ErrGeneratedColumnValue.New(target.Name, tableName)
		}

		if _, dup := alreadyListed[name]; dup {
			return plan.ErrInsertIntoDuplicateColumn.New(name)
		}
		alreadyListed[name] = struct{}{}
	}
	return nil
}

// validGeneratedColumnValue reports whether the values supplied at position idx of the
// insert source are acceptable for a generated column. An explicit default (`DEFAULT`),
// i.e. a *sql.ColumnDefaultValue either bare or inside an *expression.Wrapper, is the
// only acceptable value.
//
// It returns true only when source is a *plan.Values with at least one tuple and no
// tuple holds a non-default expression at idx. Every tuple is inspected, so a later
// row cannot supply an explicit value after a first row that used DEFAULT.
//
// Tuples that have no element at idx are skipped rather than indexed. That is an
// arity mismatch, left for the value-count validation to report, and indexing them
// would panic with an out-of-range error.
//
// Any source that is not a *plan.Values (or has no tuples) yields false.
func validGeneratedColumnValue(idx int, source sql.Node) bool {
	values, ok := source.(*plan.Values)
	if !ok || len(values.ExpressionTuples) == 0 {
		return false
	}

	for _, tuple := range values.ExpressionTuples {
		if idx < 0 || idx >= len(tuple) {
			// No value at this position in this row; not this function's error to raise.
			continue
		}
		switch val := tuple[idx].(type) {
		case *sql.ColumnDefaultValue: // should be wrapped, but just in case
			// Explicit DEFAULT: acceptable, keep checking the remaining rows.
		case *expression.Wrapper:
			if _, isDefault := val.Unwrap().(*sql.ColumnDefaultValue); !isDefault {
				return false
			}
		default:
			return false
		}
	}
	return true
}

// validateValueCount checks that the insert source provides as many values as there
// are target columns and returns sql.ErrInsertIntoMismatchValueCount when it does not.
//
// The source is unwrapped once if it is a *plan.Exchange. The supplied count is then:
//   - for *plan.Values, the length of each expression tuple (all must match);
//   - for *plan.LoadData, the length of its column list, or of its schema when the
//     column list is empty;
//   - for any other node, the length of its schema.
func validateValueCount(columnNames []string, values sql.Node) error {
	if wrapped, ok := values.(*plan.Exchange); ok {
		values = wrapped.Child
	}

	expected := len(columnNames)
	mismatch := false

	switch src := values.(type) {
	case *plan.Values:
		for i := range src.ExpressionTuples {
			if len(src.ExpressionTuples[i]) != expected {
				mismatch = true
				break
			}
		}
	case *plan.LoadData:
		supplied := len(src.ColumnNames)
		if supplied == 0 {
			// LOAD DATA without a column list targets its whole schema.
			supplied = len(src.Schema())
		}
		mismatch = supplied != expected
	default:
		// Parser assures us that this will be some form of SelectStatement, so no need to type check it
		mismatch = len(values.Schema()) != expected
	}

	if mismatch {
		return sql.ErrInsertIntoMismatchValueCount.New()
	}
	return nil
}
2 changes: 1 addition & 1 deletion sql/planbuilder/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) {
ld := plan.NewLoadData(bool(d.Local), d.Infile, sch, columnsToStrings(d.Columns), d.Fields, d.Lines, ignoreNumVal, d.IgnoreOrReplace)
outScope = inScope.push()
ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), ld, ld.IsReplace, ld.ColumnNames, nil, ld.IsIgnore)

b.validateInsert(ins)
outScope.node = ins
if rt != nil {
checks := b.loadChecksFromTable(destScope, rt.Table)
Expand Down