diff --git a/.gitattributes b/.gitattributes index 33eef5d0d0..efc6a0cd5d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -20,3 +20,5 @@ enginetest/testdata/loaddata_enclosed.dat binary enginetest/testdata/loaddata_single_quotes.dat binary enginetest/testdata/loaddata_nulls.dat binary enginetest/testdata/loaddata_escape.dat binary +enginetest/testdata/loaddata_extra_fields.dat binary +enginetest/testdata/loaddata_unterminated.dat binary diff --git a/enginetest/queries/load_queries.go b/enginetest/queries/load_queries.go index 0b3e7ce4f1..e5c3371000 100644 --- a/enginetest/queries/load_queries.go +++ b/enginetest/queries/load_queries.go @@ -24,6 +24,74 @@ import ( ) var LoadDataScripts = []ScriptTest{ + { + Name: "LOAD DATA with unterminated enclosed field", + SetUpScript: []string{ + "CREATE TABLE t_unterminated (val VARCHAR(255))", + "LOAD DATA INFILE './testdata/loaddata_unterminated.dat' INTO TABLE t_unterminated FIELDS ENCLOSED BY '\"'", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t_unterminated", + Expected: []sql.Row{ + {"\"unterminated field"}, + }, + }, + }, + }, + { + Name: "LOAD DATA with extra fields, user variables, and missing fields", + SetUpScript: []string{ + "CREATE TABLE t_extra (id INT PRIMARY KEY, val VARCHAR(255))", + "LOAD DATA INFILE './testdata/loaddata_extra_fields.dat' INTO TABLE t_extra FIELDS TERMINATED BY ',' (id, val, @extra1, @extra2)", + "CREATE TABLE t_short (id INT PRIMARY KEY, val VARCHAR(255) NOT NULL DEFAULT 'default')", + "LOAD DATA INFILE './testdata/loaddata_extra_fields.dat' INTO TABLE t_short FIELDS TERMINATED BY ',' (id, val)", + "CREATE TABLE t_defaults (id INT PRIMARY KEY, val VARCHAR(255) DEFAULT 'default')", + "LOAD DATA INFILE './testdata/loaddata_extra_fields.dat' INTO TABLE t_defaults FIELDS TERMINATED BY ',' (id)", + "CREATE TABLE t_discard (id INT PRIMARY KEY, val VARCHAR(255))", + "LOAD DATA INFILE './testdata/loaddata_extra_fields.dat' INTO TABLE t_discard FIELDS TERMINATED BY ',' (id, 
@discard, val)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t_extra ORDER BY id", + Expected: []sql.Row{ + {1, "val1"}, + {2, "val2"}, + {3, nil}, + }, + }, + { + Query: "SELECT * FROM t_short ORDER BY id", + Expected: []sql.Row{ + {1, "val1"}, + {2, "val2"}, + {3, ""}, + }, + }, + { + Query: "SELECT * FROM t_defaults ORDER BY id", + Expected: []sql.Row{ + {1, "default"}, + {2, "default"}, + {3, "default"}, + }, + }, + { + Query: "SELECT * FROM t_discard ORDER BY id", + Expected: []sql.Row{ + {1, "extra1"}, + {2, "extra3"}, + {3, nil}, + }, + }, + { + Query: "SELECT @extra1, @extra2, @discard", + Expected: []sql.Row{ + {nil, nil, nil}, + }, + }, + }, + }, { // https://github.com/dolthub/dolt/issues/9969 Name: "LOAD DATA with ENCLOSED BY and ESCAPED BY parsing", diff --git a/enginetest/testdata/loaddata_extra_fields.dat b/enginetest/testdata/loaddata_extra_fields.dat new file mode 100644 index 0000000000..5be28160ac --- /dev/null +++ b/enginetest/testdata/loaddata_extra_fields.dat @@ -0,0 +1,3 @@ +1,val1,extra1,extra2 +2,val2,extra3 +3 \ No newline at end of file diff --git a/enginetest/testdata/loaddata_unterminated.dat b/enginetest/testdata/loaddata_unterminated.dat new file mode 100644 index 0000000000..a81a63fb79 --- /dev/null +++ b/enginetest/testdata/loaddata_unterminated.dat @@ -0,0 +1 @@ +"unterminated field \ No newline at end of file diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index cba7fab76a..a5a0996cf8 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -108,8 +108,8 @@ func (b *BaseBuilder) buildLoadData(ctx *sql.Context, n *plan.LoadData, row sql. 
} fieldToColMap := make([]int, len(n.UserVars)) - for fieldIdx, colIdx := 0, 0; fieldIdx < len(n.UserVars) && colIdx < len(colNames); fieldIdx++ { - if n.UserVars[fieldIdx] != nil { + for fieldIdx, colIdx := 0, 0; fieldIdx < len(n.UserVars); fieldIdx++ { + if n.UserVars[fieldIdx] != nil || colIdx >= len(colNames) { fieldToColMap[fieldIdx] = -1 continue } diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index d9cc0bb6fe..aaac4a0c40 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -225,93 +225,128 @@ func (l *loadDataIter) parseFields(ctx *sql.Context, line string) (exprs []sql.E lastField := currentField.String() // If still in enclosure at EOF when enc==esc, prepend the opening enclosure that was stripped - if inEnclosure && !normalLineTerm { + if inEnclosure { lastField = string(l.fieldsEnclosedBy[0]) + lastField } fields = append(fields, lastField) - if inEnclosure && normalLineTerm { - return nil, fmt.Errorf("error: unterminated enclosed field") - } - - fieldRow := make(sql.Row, len(fields)) - for i, field := range fields { - fieldRow[i] = field + exprs, colListRow, rowFieldToColMap, err := l.inputPreprocessor(ctx, fields) + if err != nil { + return nil, err } - exprs = make([]sql.Expression, len(l.destSch)) - for fieldIdx, exprIdx := 0, 0; fieldIdx < len(fields) && fieldIdx < len(l.userVars); fieldIdx++ { - if l.userVars[fieldIdx] != nil { - setField := l.userVars[fieldIdx].(*expression.SetField) - userVar := setField.LeftChild.(*expression.UserVar) - err := setUserVar(ctx, userVar, setField.RightChild, fieldRow) + for exprIdx, expr := range exprs { + if expr != nil { + result, err := expr.Eval(ctx, colListRow) if err != nil { return nil, err } + exprs[exprIdx] = expression.NewLiteral(result, expr.Type()) continue } - // don't check for `exprIdx < len(exprs)` in for loop - // because we still need to assign trailing user variables - if exprIdx >= len(exprs) { + destColIdx := rowFieldToColMap[exprIdx] + if destColIdx == -1 
{ continue } - field := fields[fieldIdx] - switch field { - case "": - // Replace the empty string with defaults if exists, otherwise NULL - destCol := l.destSch[l.fieldToColMap[fieldIdx]] - if _, ok := destCol.Type.(sql.StringType); ok { - exprs[exprIdx] = expression.NewLiteral(field, types.LongText) - } else { - if destCol.Default != nil { - exprs[exprIdx] = destCol.Default - } else { - exprs[exprIdx] = expression.NewLiteral(nil, types.Null) + field := colListRow[exprIdx] + destCol := l.destSch[destColIdx] + + if field != nil { + switch field { + case "": + if _, ok := destCol.Type.(sql.StringType); ok { + exprs[exprIdx] = expression.NewLiteral(field, types.LongText) } + case "NULL": + exprs[exprIdx] = expression.NewLiteral(nil, types.Null) + default: + exprs[exprIdx] = expression.NewLiteral(field, types.LongText) } - case "NULL": - exprs[exprIdx] = expression.NewLiteral(nil, types.Null) - default: - exprs[exprIdx] = expression.NewLiteral(field, types.LongText) - } - exprIdx++ - } - - // Apply Set Expressions by replacing the corresponding field expression with the set expression - for fieldIdx, exprIdx := 0, 0; len(l.setExprs) > 0 && fieldIdx < len(l.fieldToColMap) && exprIdx < len(exprs); fieldIdx++ { - setIdx := l.fieldToColMap[fieldIdx] - if setIdx == -1 { continue } - setExpr := l.setExprs[setIdx] - if setExpr != nil { - res, err := setExpr.Eval(ctx, fieldRow) - if err != nil { - return nil, err - } - exprs[exprIdx] = expression.NewLiteral(res, setExpr.Type()) + + // If the field is still nil, the input line did not contain enough fields to satisfy the column list. For + // non-nullable columns, MySQL treats this as a data truncation and assigns the implicit "zero value" for the + // data type (e.g. an empty string or 0) instead of the explicit schema default. 
+ if !destCol.Nullable && !destCol.AutoIncrement { + exprs[exprIdx] = expression.NewLiteral(destCol.Type.Zero(), destCol.Type) + } else { + exprs[exprIdx] = destCol.Default } - exprIdx++ } - // Due to how projections work, if no columns are provided (each row may have a variable number of values), the - // projection will not insert default values, so we must do it here. - if l.colCount == 0 { - for exprIdx, expr := range exprs { - if expr != nil { + return exprs, nil +} + +// inputPreprocessor takes in the |parsedFields| extracted from a [plan.LoadData.File] line to correctly place +// preprocessors (i.e. the SET clause allowing you to perform transformations on values before assigning their result to +// a column), and to reindex new field positions without user variables into a [sql.Row], and [sql.Expression] array. +// Per row results can differentiate, and we only care about column fields for expressions anyway, so we don't include +// user variables in the returned results of this function. If a user variable is included in the returned [sql.Row], it +// could mess with the projection of other fields, because it offsets anything that comes after it. +// +// For more information on preprocessors, see the documentation for "[Input Preprocessing]". +// +// [Input Preprocessing]: https://dev.mysql.com/doc/refman/9.5/en/load-data.html#load-data-input-preprocessing +func (l *loadDataIter) inputPreprocessor(ctx *sql.Context, parsedFields []string) (expressions []sql.Expression, colListRow sql.Row, rowFieldToColMap map[int]int, err error) { + colListRow = make(sql.Row, len(l.destSch)) + expressions = make([]sql.Expression, len(l.destSch)) + rowFieldToColMap = make(map[int]int) + // colListIdx must only increment on column fields or preprocessors. 
+ colListIdx := 0 + for fieldIdx, destColIdx := range l.fieldToColMap { + if l.userVars[fieldIdx] != nil { + setField := l.userVars[fieldIdx].(*expression.SetField) + userVar := setField.LeftChild.(*expression.UserVar) + if fieldIdx >= len(parsedFields) { + err = ctx.SetUserVariable(ctx, userVar.Name, nil, types.Null) + if err != nil { + return nil, nil, nil, err + } continue } - col := l.destSch[exprIdx] - if !col.Nullable && col.Default == nil && !col.AutoIncrement { - return nil, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name) + + field := parsedFields[fieldIdx] + fieldType := types.ApproximateTypeFromValue(field) + err = ctx.SetUserVariable(ctx, userVar.Name, field, fieldType) + if err != nil { + return nil, nil, nil, err } - exprs[exprIdx] = col.Default + continue } + + // We've filled all possible column-only fields for the destination [sql.Schema], but other user variables could + // exist past this length. Any other `continue` statements below apply the same thought process. + if colListIdx >= len(expressions) { + continue + } + + rowFieldToColMap[colListIdx] = destColIdx + + // The preprocessors are placed ahead of time to let callers know they can be evaluated and should not be + // overwritten. loadDataIter.setExprs uses the destination [sql.Schema] indices as its map. This deviates from + // loadDataIter.fieldToColMap which includes all column field indices first, and *then* preprocessor expression + // indices (incrementing from where columns fields left off). For that reason, we must use the destination + // column index to get the correct expression. + if l.setExprs != nil && destColIdx != -1 && l.setExprs[destColIdx] != nil { + expressions[colListIdx] = l.setExprs[destColIdx] + } + + if fieldIdx >= len(parsedFields) { + colListIdx++ + continue + } + + // We need to provide a [sql.Row] with all non-user variable claimed values to later evaluate expressions that + // rely on their values. 
+ field := parsedFields[fieldIdx] + colListRow[colListIdx] = field + colListIdx++ } - return exprs, nil + return expressions, colListRow, rowFieldToColMap, nil } type modifyColumnIter struct {