diff --git a/go.mod b/go.mod index a7203b0b62..630c85f845 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 - github.com/dolthub/go-mysql-server v0.19.1-0.20250305230031-14a57e076a0a + github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index 1464050ea3..3e784c444d 100644 --- a/go.sum +++ b/go.sum @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250305230031-14a57e076a0a h1:lemFIUt0NCKIeX7vnU2yKF8UIgc0DT8zIoEUn7oy+60= -github.com/dolthub/go-mysql-server v0.19.1-0.20250305230031-14a57e076a0a/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731 h1:flDUXUqKRo4u5gdoQZBeO3jESUnCNkv01GDmiZbgAA4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 0282e5c011..2e1d605c8d 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -37,6 +37,7 @@ const ( ruleId_OptimizeFunctions // optimizeFunctions ruleId_ValidateColumnDefaults // validateColumnDefaults ruleId_ValidateCreateTable // validateCreateTable + rulesId_ResolveAlterColumn // resolveAlterColumn ) // Init adds additional rules to the analyzer to handle Doltgres-specific functionality. @@ -56,11 +57,18 @@ func Init() { analyzer.OnceBeforeDefault = append([]analyzer.Rule{{Id: ruleId_AddImplicitPrefixLengths, Apply: AddImplicitPrefixLengths}}, analyzer.OnceBeforeDefault...) + // We remove several validation rules and substitute our own analyzer.OnceBeforeDefault = insertAnalyzerRules(analyzer.OnceBeforeDefault, analyzer.ValidateCreateTableId, true, analyzer.Rule{Id: ruleId_ValidateCreateTable, Apply: validateCreateTable}) + analyzer.OnceBeforeDefault = insertAnalyzerRules(analyzer.OnceBeforeDefault, analyzer.ResolveAlterColumnId, true, + analyzer.Rule{Id: rulesId_ResolveAlterColumn, Apply: resolveAlterColumn}) - // We remove the original column default and create table validation rules, as we have our own implementations - analyzer.OnceBeforeDefault = removeAnalyzerRules(analyzer.OnceBeforeDefault, analyzer.ValidateColumnDefaultsId, analyzer.ValidateCreateTableId) + analyzer.OnceBeforeDefault = removeAnalyzerRules( + analyzer.OnceBeforeDefault, + analyzer.ValidateColumnDefaultsId, + analyzer.ValidateCreateTableId, + analyzer.ResolveAlterColumnId, + ) // Remove all other validation rules that do not apply to Postgres analyzer.DefaultValidationRules = removeAnalyzerRules(analyzer.DefaultValidationRules, analyzer.ValidateOperandsId) @@ -82,9 +90,11 @@ func Init() { // insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice. func insertAnalyzerRules(rules []analyzer.Rule, id analyzer.RuleId, before bool, additionalRules ...analyzer.Rule) []analyzer.Rule { + inserted := false newRules := make([]analyzer.Rule, len(rules)+len(additionalRules)) for i, rule := range rules { if rule.Id == id { + inserted = true if before { copy(newRules, rules[:i]) copy(newRules[i:], additionalRules) @@ -97,6 +107,11 @@ func insertAnalyzerRules(rules []analyzer.Rule, id analyzer.RuleId, before bool, break } } + + if !inserted { + panic("no rules were inserted") + } + return newRules } @@ -106,11 +121,20 @@ func removeAnalyzerRules(rules []analyzer.Rule, remove ...analyzer.RuleId) []ana for _, removal := range remove { ids[removal] = struct{}{} } + + removedIds := 0 var newRules []analyzer.Rule for _, rule := range rules { if _, ok := ids[rule.Id]; !ok { newRules = append(newRules, rule) + } else { + removedIds++ } } + + if removedIds < len(remove) { + panic("one or more rules were not removed, this is a bug") + } + return newRules } diff --git a/server/analyzer/validate_create_table.go b/server/analyzer/validate_create_table.go index eadca47402..370b6bf5f0 100755 --- a/server/analyzer/validate_create_table.go +++ b/server/analyzer/validate_create_table.go @@ -49,14 +49,16 @@ func validateCreateTable(ctx *sql.Context, a *analyzer.Analyzer, n sql.Node, sco // validateIdentifiers validates the names of all schema elements for validity // TODO: we use 64 character as the max length for an identifier, postgres uses 63 func validateIdentifiers(ct *plan.CreateTable) error { - if len(ct.Name()) > sql.MaxIdentifierLength { - return sql.ErrInvalidIdentifier.New(ct.Name()) + err := analyzer.ValidateIdentifier(ct.Name()) + if err != nil { + return err } colNames := make(map[string]bool) for _, col := range ct.PkSchema().Schema { - if len(col.Name) > sql.MaxIdentifierLength { - return sql.ErrInvalidIdentifier.New(col.Name) + err = analyzer.ValidateIdentifier(col.Name) + if err != nil { + return err } lower := strings.ToLower(col.Name) if colNames[lower] { @@ -66,20 +68,23 @@ func validateIdentifiers(ct *plan.CreateTable) error { } for _, chDef := range ct.Checks() { - if len(chDef.Name) > sql.MaxIdentifierLength { - return sql.ErrInvalidIdentifier.New(chDef.Name) + err = analyzer.ValidateIdentifier(chDef.Name) + if err != nil { + return err } } for _, idxDef := range ct.Indexes() { - if len(idxDef.Name) > sql.MaxIdentifierLength { - return sql.ErrInvalidIdentifier.New(idxDef.Name) + err = analyzer.ValidateIdentifier(idxDef.Name) + if err != nil { + return err } } for _, fkDef := range ct.ForeignKeys() { - if len(fkDef.Name) > sql.MaxIdentifierLength { - return sql.ErrInvalidIdentifier.New(fkDef.Name) + err = analyzer.ValidateIdentifier(fkDef.Name) + if err != nil { + return err } } @@ -128,3 +133,316 @@ func validateIndex(ctx *sql.Context, colMap map[string]*sql.Column, idxDef *sql. return nil } + +// resolveAlterColumn is a validation rule that validates the schema changes in an ALTER TABLE statement and updates +// the nodes with necessary intermediate / update schema information +func resolveAlterColumn(ctx *sql.Context, a *analyzer.Analyzer, n sql.Node, scope *plan.Scope, sel analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + if !analyzer.FlagIsSet(qFlags, sql.QFlagAlterTable) { + return n, transform.SameTree, nil + } + + var sch sql.Schema + var indexes []string + var validator sql.SchemaValidator + keyedColumns := make(map[string]bool) + var err error + transform.Inspect(n, func(n sql.Node) bool { + if st, ok := n.(sql.SchemaTarget); ok { + sch = st.TargetSchema() + } + switch n := n.(type) { + case *plan.ModifyColumn: + if rt, ok := n.Table.(*plan.ResolvedTable); ok { + if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok { + validator = sv + } + } + keyedColumns, err = analyzer.GetTableIndexColumns(ctx, n.Table) + return false + case *plan.RenameColumn: + if rt, ok := n.Table.(*plan.ResolvedTable); ok { + if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok { + validator = sv + } + } + return false + case *plan.AddColumn: + if rt, ok := n.Table.(*plan.ResolvedTable); ok { + if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok { + validator = sv + } + } + keyedColumns, err = analyzer.GetTableIndexColumns(ctx, n.Table) + return false + case *plan.DropColumn: + if rt, ok := n.Table.(*plan.ResolvedTable); ok { + if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok { + validator = sv + } + } + return false + case *plan.AlterIndex: + if rt, ok := n.Table.(*plan.ResolvedTable); ok { + if sv, ok := rt.UnwrappedDatabase().(sql.SchemaValidator); ok { + validator = sv + } + } + indexes, err = analyzer.GetTableIndexNames(ctx, a, n.Table) + default: + } + return true + }) + + if err != nil { + return nil, transform.SameTree, err + } + + // Skip this validation if we didn't find one or more of the above node types + if len(sch) == 0 { + return n, transform.SameTree, nil + } + + sch = sch.Copy() // Make a copy of the original schema to deal with any references to the original table. + initialSch := sch + + // Need a TransformUp here because multiple of these statement types can be nested under a Block node. + // It doesn't look it, but this is actually an iterative loop over all the independent clauses in an ALTER statement + n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + switch nn := n.(type) { + case *plan.ModifyColumn: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + + sch, err = analyzer.ValidateModifyColumn(ctx, initialSch, sch, n.(*plan.ModifyColumn), keyedColumns) + if err != nil { + return nil, transform.SameTree, err + } + return n, transform.NewTree, nil + case *plan.RenameColumn: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + sch, err = analyzer.ValidateRenameColumn(initialSch, sch, n.(*plan.RenameColumn)) + if err != nil { + return nil, transform.SameTree, err + } + return n, transform.NewTree, nil + case *plan.AddColumn: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + + sch, err = analyzer.ValidateAddColumn(sch, n.(*plan.AddColumn)) + if err != nil { + return nil, transform.SameTree, err + } + + return n, transform.NewTree, nil + case *plan.DropColumn: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + sch, err = analyzer.ValidateDropColumn(initialSch, sch, n.(*plan.DropColumn)) + if err != nil { + return nil, transform.SameTree, err + } + delete(keyedColumns, nn.Column) + + return n, transform.NewTree, nil + case *plan.AlterIndex: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + indexes, err = validateAlterIndex(ctx, initialSch, sch, n.(*plan.AlterIndex), indexes) + if err != nil { + return nil, transform.SameTree, err + } + + keyedColumns = analyzer.UpdateKeyedColumns(keyedColumns, nn) + return n, transform.NewTree, nil + case *plan.AlterPK: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + sch, err = validatePrimaryKey(ctx, initialSch, sch, n.(*plan.AlterPK)) + if err != nil { + return nil, transform.SameTree, err + } + return n, transform.NewTree, nil + case *plan.AlterDefaultSet: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + sch, err = analyzer.ValidateAlterDefault(initialSch, sch, n.(*plan.AlterDefaultSet)) + if err != nil { + return nil, transform.SameTree, err + } + return n, transform.NewTree, nil + case *plan.AlterDefaultDrop: + n, err := nn.WithTargetSchema(sch.Copy()) + if err != nil { + return nil, transform.SameTree, err + } + sch, err = analyzer.ValidateDropDefault(initialSch, sch, n.(*plan.AlterDefaultDrop)) + if err != nil { + return nil, transform.SameTree, err + } + return n, transform.NewTree, nil + } + return n, transform.SameTree, nil + }) + + if err != nil { + return nil, transform.SameTree, err + } + + if validator != nil { + if err := validator.ValidateSchema(sch); err != nil { + return nil, transform.SameTree, err + } + } + + return n, same, nil +} + +// Returns the underlying table name for the node given +func getTableName(node sql.Node) string { + var tableName string + transform.Inspect(node, func(node sql.Node) bool { + switch node := node.(type) { + case *plan.TableAlias: + tableName = node.Name() + return false + case *plan.ResolvedTable: + tableName = node.Name() + return false + case *plan.UnresolvedTable: + tableName = node.Name() + return false + case *plan.IndexedTableAccess: + tableName = node.Name() + return false + } + return true + }) + + return tableName +} + +// validatePrimaryKey validates a primary key add or drop operation. +func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterPK) (sql.Schema, error) { + tableName := getTableName(ai.Table) + switch ai.Action { + case plan.PrimaryKeyAction_Create: + if analyzer.HasPrimaryKeys(sch) { + return nil, sql.ErrMultiplePrimaryKeysDefined.New() + } + + colMap := schToColMap(sch) + idxDef := &sql.IndexDef{ + Name: "PRIMARY", + Columns: ai.Columns, + Constraint: sql.IndexConstraint_Primary, + } + + err := validateIndex(ctx, colMap, idxDef) + if err != nil { + return nil, err + } + + for _, idxCol := range ai.Columns { + schCol := colMap[strings.ToLower(idxCol.Name)] + if schCol.Virtual { + return nil, sql.ErrVirtualColumnPrimaryKey.New() + } + } + + // Set the primary keys + for _, col := range ai.Columns { + sch[sch.IndexOf(col.Name, tableName)].PrimaryKey = true + } + + return sch, nil + case plan.PrimaryKeyAction_Drop: + if !analyzer.HasPrimaryKeys(sch) { + return nil, sql.ErrCantDropFieldOrKey.New("PRIMARY") + } + + for _, col := range sch { + if col.PrimaryKey { + col.PrimaryKey = false + } + } + + return sch, nil + default: + return sch, nil + } +} + +// validateAlterIndex validates the specified column can have an index added, dropped, or renamed. Returns an updated +// list of index name given the add, drop, or rename operations. +func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterIndex, indexes []string) ([]string, error) { + switch ai.Action { + case plan.IndexAction_Create: + err := analyzer.ValidateIdentifier(ai.IndexName) + if err != nil { + return nil, err + } + colMap := schToColMap(sch) + + // TODO: plan.AlterIndex should just have a sql.IndexDef + indexDef := &sql.IndexDef{ + Name: ai.IndexName, + Columns: ai.Columns, + Constraint: ai.Constraint, + Storage: ai.Using, + Comment: ai.Comment, + } + + err = validateIndex(ctx, colMap, indexDef) + if err != nil { + return nil, err + } + return append(indexes, ai.IndexName), nil + case plan.IndexAction_Drop: + savedIdx := -1 + for i, idx := range indexes { + if strings.EqualFold(idx, ai.IndexName) { + savedIdx = i + break + } + } + if savedIdx == -1 { + return nil, sql.ErrCantDropFieldOrKey.New(ai.IndexName) + } + // Remove the index from the list + return append(indexes[:savedIdx], indexes[savedIdx+1:]...), nil + case plan.IndexAction_Rename: + err := analyzer.ValidateIdentifier(ai.IndexName) + if err != nil { + return nil, err + } + savedIdx := -1 + for i, idx := range indexes { + if strings.EqualFold(idx, ai.PreviousIndexName) { + savedIdx = i + } + } + if savedIdx == -1 { + return nil, sql.ErrCantDropFieldOrKey.New(ai.IndexName) + } + // Simulate the rename by deleting the old name and adding the new one. + return append(append(indexes[:savedIdx], indexes[savedIdx+1:]...), ai.IndexName), nil + } + + return indexes, nil +} diff --git a/testing/go/alter_table_test.go b/testing/go/alter_table_test.go index 96deab1737..7c0e0a460d 100644 --- a/testing/go/alter_table_test.go +++ b/testing/go/alter_table_test.go @@ -227,6 +227,31 @@ func TestAlterTable(t *testing.T) { }, }, }, + { + Name: "Add Primary Key on text column", + SetUpScript: []string{ + "CREATE TABLE test1 (a text, b INT);", + "insert into test1 values ('a', 1), ('b', 2);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "ALTER TABLE test1 ADD PRIMARY KEY (a);", + Expected: []sql.Row{}, + }, + { + // Test the pk by inserting a duplicate value + Query: "INSERT into test1 values ('a', 3);", + ExpectedErr: "duplicate primary key", + }, + { + Query: "select * from test1;", + Expected: []sql.Row{ + {"a", 1}, + {"b", 2}, + }, + }, + }, + }, { Name: "Add primary key with generated column", SetUpScript: []string{