Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions enginetest/queries/load_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ var LoadDataScripts = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'",

Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"'",
ExpectedErrStr: "Check constraint \"loadtable_chk_1\" violated",
},
},
Expand Down Expand Up @@ -275,11 +274,19 @@ var LoadDataErrorScripts = []ScriptTest{
{
Name: "Load data with unknown columns throws an error",
SetUpScript: []string{
"create table loadtable(pk int primary key)",
"create table loadtable(pk int primary key, i int)",
},
Assertions: []ScriptTestAssertion{
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (bad)",
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (fake_col, pk, i)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, fake_col, i)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
{
Query: "LOAD DATA INFILE './testdata/test1.txt' INTO TABLE loadtable FIELDS ENCLOSED BY '\"' (pk, i, fake_col)",
ExpectedErr: plan.ErrInsertIntoNonexistentColumn,
},
},
Expand Down
26 changes: 0 additions & 26 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
return nil, transform.SameTree, err
}

if insert.IsReplace {
var ok bool
_, ok = insertable.(sql.ReplaceableTable)
if !ok {
return nil, transform.SameTree, plan.ErrReplaceIntoNotSupported.New()
}
}

if len(insert.OnDupExprs) > 0 {
var ok bool
_, ok = insertable.(sql.UpdatableTable)
if !ok {
return nil, transform.SameTree, plan.ErrOnDuplicateKeyUpdateNotSupported.New()
}
}

source := insert.Source
// TriggerExecutor has already been analyzed
if _, ok := insert.Source.(*plan.TriggerExecutor); !ok {
Expand Down Expand Up @@ -93,16 +77,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
for i, f := range dstSchema {
columnNames[i] = f.Name
}
} else {
err = validateColumns(table.Name(), columnNames, dstSchema, source)
if err != nil {
return nil, transform.SameTree, err
}
}

err = validateValueCount(columnNames, source)
if err != nil {
return nil, transform.SameTree, err
}

// The schema of the destination node and the underlying table differ subtly in terms of defaults
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ const (
validateUnionSchemasMatchId // validateUnionSchemasMatch
validateAggregationsId // validateAggregations
validateDeleteFromId // validateDeleteFrom
validateInsertsId // validateInserts

// after all
cacheSubqueryResultsId // cacheSubqueryResults
Expand Down
17 changes: 9 additions & 8 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ var DefaultValidationRules = []Rule{
{validateSubqueryColumnsId, validateSubqueryColumns},
{validateUnionSchemasMatchId, validateUnionSchemasMatch},
{validateAggregationsId, validateAggregations},
{validateInsertsId, validateInserts},
}

var OnceAfterAll = []Rule{
Expand Down
69 changes: 69 additions & 0 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,75 @@ func validateAggregations(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
return n, transform.SameTree, validationErr
}

// validateInserts validates InsertInto nodes.
func validateInserts(_ *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) {
if _, ok := n.(*plan.TriggerExecutor); ok {
return n, transform.SameTree, nil
} else if _, ok := n.(*plan.CreateProcedure); ok {
return n, transform.SameTree, nil
}

return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
// Skip any nodes that aren't an InsertInto
ii, ok := n.(*plan.InsertInto)
if !ok {
return n, transform.SameTree, nil
}

table := getResolvedTable(ii.Destination)

insertable, err := plan.GetInsertable(table)
if err != nil {
return n, transform.SameTree, err
}

if ii.IsReplace {
var ok bool
_, ok = insertable.(sql.ReplaceableTable)
if !ok {
return n, transform.SameTree, plan.ErrReplaceIntoNotSupported.New()
}
}

if len(ii.OnDupExprs) > 0 {
var ok bool
_, ok = insertable.(sql.UpdatableTable)
if !ok {
return n, transform.SameTree, plan.ErrOnDuplicateKeyUpdateNotSupported.New()
}
}

// normalize the column name
dstSchema := insertable.Schema()
columnNames := make([]string, len(ii.ColumnNames))
for i, name := range ii.ColumnNames {
columnNames[i] = strings.ToLower(name)
}

// If no columns are given and value tuples are not all empty, use the full schema
if len(columnNames) == 0 && existsNonZeroValueCount(ii.Source) {
columnNames = make([]string, len(dstSchema))
for i, f := range dstSchema {
columnNames[i] = f.Name
}
}

if len(ii.ColumnNames) > 0 {
err := validateColumns(table.Name(), columnNames, dstSchema, ii.Source)
if err != nil {
return n, transform.SameTree, err
}
}

err = validateValueCount(columnNames, ii.Source)
if err != nil {
return nil, transform.SameTree, err
}

return n, transform.SameTree, nil
})
}

// checkForAggregationFunctions returns an ErrAggregationUnsupported error if any aggregation
// functions are found in the specified expressions.
func checkForAggregationFunctions(exprs []sql.Expression) error {
Expand Down