diff --git a/go.mod b/go.mod index 630c85f845..f4d415ae45 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 - github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731 + github.com/dolthub/go-mysql-server v0.19.1-0.20250307161823-e8ce0df0d8f2 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index 3e784c444d..80e35540b7 100644 --- a/go.sum +++ b/go.sum @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731 h1:flDUXUqKRo4u5gdoQZBeO3jESUnCNkv01GDmiZbgAA4= -github.com/dolthub/go-mysql-server v0.19.1-0.20250306014046-f73a318f7731/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250307161823-e8ce0df0d8f2 h1:5AAJTJWaiYO1ut8TkEKbeBTp3x0UIOd1yV9nC1s3vjg= +github.com/dolthub/go-mysql-server v0.19.1-0.20250307161823-e8ce0df0d8f2/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 2e1d605c8d..27a65c4428 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -15,7 +15,12 @@ package analyzer import ( + "strings" + + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/vitess/go/sqltypes" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -86,6 +91,49 @@ func Init() { analyzer.Rule{Id: ruleId_AddDomainConstraintsToCasts, Apply: AddDomainConstraintsToCasts}, analyzer.Rule{Id: ruleId_ReplaceNode, Apply: ReplaceNode}, analyzer.Rule{Id: ruleId_InsertContextRootFinalizer, Apply: InsertContextRootFinalizer}) + + initEngine() +} + +func initEngine() { + plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition +} + +// validateForeignKeyDefinition validates that the given foreign key definition is valid for creation +func validateForeignKeyDefinition(ctx *sql.Context, fkDef sql.ForeignKeyConstraint, cols map[string]*sql.Column, parentCols map[string]*sql.Column) error { + // TODO: this check is too permissive, we should be doing some type checks here + for i := range fkDef.Columns { + col := cols[strings.ToLower(fkDef.Columns[i])] + parentCol := parentCols[strings.ToLower(fkDef.ParentColumns[i])] + if !foreignKeyComparableTypes(ctx, col.Type, parentCol.Type) { + return sql.ErrForeignKeyColumnTypeMismatch.New(fkDef.Columns[i], fkDef.ParentColumns[i]) + } + } + return nil +} + +// foreignKeyComparableTypes returns whether the two given types are able to be used as parent/child columns in a +// foreign key. +func foreignKeyComparableTypes(ctx *sql.Context, type1 sql.Type, type2 sql.Type) bool { + if !type1.Equals(type2) { + // There seems to be a special case where CHAR/VARCHAR/BINARY/VARBINARY can have unequal lengths. + // Have not tested every type nor combination, but this seems specific to those 4 types. + if type1.Type() == type2.Type() { + switch type1.Type() { + case sqltypes.Char, sqltypes.VarChar, sqltypes.Binary, sqltypes.VarBinary: + type1String := type1.(sql.StringType) + type2String := type2.(sql.StringType) + if type1String.Collation().CharacterSet() != type2String.Collation().CharacterSet() { + return false + } + default: + return false + } + } else { + return false + } + } + return true } // insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice. diff --git a/testing/go/foreign_keys_test.go b/testing/go/foreign_keys_test.go index aee7ca707a..7b659ca21b 100755 --- a/testing/go/foreign_keys_test.go +++ b/testing/go/foreign_keys_test.go @@ -44,6 +44,26 @@ func TestForeignKeys(t *testing.T) { }, }, }, + { + Name: "text foreign key", + SetUpScript: []string{ + `CREATE TABLE parent (a text PRIMARY KEY, b int)`, + `CREATE TABLE child (a INT PRIMARY KEY, b text, FOREIGN KEY (b) REFERENCES parent(a))`, + `INSERT INTO parent VALUES ('a', 1)`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO child VALUES (1, 'a')", + }, + { + Query: "INSERT INTO child VALUES (2, 'a')", + }, + { + Query: "INSERT INTO child VALUES (3, 'b')", + ExpectedErr: "Foreign key violation", + }, + }, + }, { Name: "foreign key with dolt_add, dolt_commit", SetUpScript: []string{