diff --git a/enginetest/queries/create_table_queries.go b/enginetest/queries/create_table_queries.go index 69079bc78f..2ae413bc23 100644 --- a/enginetest/queries/create_table_queries.go +++ b/enginetest/queries/create_table_queries.go @@ -228,6 +228,36 @@ var CreateTableQueries = []WriteQueryTest{ SelectQuery: `select * from t1 order by i`, ExpectedSelect: []sql.Row{{"newfirst row", 1}, {"newsecond row", 2}, {"newthird row", 3}}, }, + { + WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key collate binary)`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: `SHOW CREATE TABLE t1`, + ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key charset binary)`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: `SHOW CREATE TABLE t1`, + ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key character set binary)`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: `SHOW CREATE TABLE t1`, + ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key charset binary collate binary)`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: `SHOW CREATE TABLE t1`, + ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key character set binary collate binary)`, + ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}}, + SelectQuery: `SHOW CREATE TABLE t1`, + ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, } var CreateTableScriptTests = []ScriptTest{ @@ -496,6 +526,38 @@ var CreateTableScriptTests = []ScriptTest{ }, }, }, + { + Name: "valid character set and collation options", + SetUpScript: []string{ + "create table parent (a int primary key)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE TABLE t1 (pk varbinary(10) primary key collate utf8mb4_0900_bin)`, + ExpectedErr: types.ErrBinaryCollation, + }, + { + Query: `CREATE TABLE t1 (pk varbinary(10) primary key charset utf8mb4_0900_bin)`, + ExpectedErr: types.ErrCharacterSetOnInvalidType, + }, + { + Query: `CREATE TABLE t1 (pk varbinary(10) primary key character set utf8mb4)`, + ExpectedErr: types.ErrCharacterSetOnInvalidType, + }, + { + Query: `CREATE TABLE t1 (pk varbinary(10) primary key charset utf8mb4 collate utf8mb4_0900_bin)`, + ExpectedErr: types.ErrCharacterSetOnInvalidType, + }, + { + Query: `CREATE TABLE t1 (pk varbinary(10) primary key character set utf8mb4 collate utf8mb4_0900_bin)`, + ExpectedErr: types.ErrCharacterSetOnInvalidType, + }, + { + Query: `CREATE TABLE t1 (pk int primary key character set utf8mb4)`, + ExpectedErr: types.ErrCharacterSetOnInvalidType, + }, + }, + }, } var CreateTableAutoIncrementTests = []ScriptTest{ diff --git a/go.mod b/go.mod index 80cb8f2b8e..5261609a0e 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f + github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.1 github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index b210a131d5..1706f7eaaf 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,8 @@ github.com/dolthub/vitess v0.0.0-20231109003730-c0fa018b5ef6 h1:/GOBV8ceNCMuyS9/ github.com/dolthub/vitess v0.0.0-20231109003730-c0fa018b5ef6/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f h1:I480LKHhb4usnF3dYhp6J4ORKMrncNKaWYZvIZwlK+U= github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674 h1:OYEf4PpMUG7rj51/l/Da9Iv9tAMnBvnEi+dA+tsoKGA= +github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 5042acfa10..0b9f373be5 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -1037,9 +1037,10 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t generated := make([]ast.Expr, len(tableSpec.Columns)) var schema sql.Schema for i, cd := range tableSpec.Columns { + sqlType := cd.Type.SQLType() // Use the table's collation if no character or collation was specified for the table if len(cd.Type.Charset) == 0 && len(cd.Type.Collate) == 0 { - if tableCollation != sql.Collation_Unspecified { + if tableCollation != sql.Collation_Unspecified && !types.IsBinary(sqlType) { cd.Type.Collate = tableCollation.Name() } } diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 4bf5f54598..7cc19f3d80 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -21,8 +21,10 @@ import ( "time" "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" ) @@ -116,8 +118,47 @@ func ApproximateTypeFromValue(val interface{}) sql.Type { } } +// IsBinary returns whether the type represents binary data. +func IsBinary(sqlType query.Type) bool { + switch sqlType { + case sqltypes.Binary, + sqltypes.VarBinary, + sqltypes.Blob, + sqltypes.TypeJSON, + sqltypes.Geometry: + return true + } + return false +} + +func allowsCharSet(sqlType query.Type) bool { + switch sqlType { + case sqltypes.VarChar, + sqltypes.Char, + sqltypes.Text, + sqltypes.Enum, + sqltypes.Set: + return true + } + return false +} + +var ErrCharacterSetOnInvalidType = errors.NewKind("Only character columns, enums, and sets can have a CHARACTER SET option") + // ColumnTypeToType gets the column type using the column definition. func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { + + sqlType := ct.SQLType() + + if !allowsCharSet(sqlType) && len(ct.Charset) != 0 { + return nil, ErrCharacterSetOnInvalidType.New() + } + + collate := ct.Collate + if IsBinary(sqlType) && collate == "" { + collate = sql.Collation_binary.Name() + } + switch strings.ToLower(ct.Type) { case "boolean", "bool": return Boolean, nil @@ -206,29 +247,14 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { } } return CreateBitType(uint8(length)) - case "tinyblob": - return TinyBlob, nil - case "blob": - if ct.Length == nil { - return Blob, nil - } - length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64) - if err != nil { - return nil, err - } - return CreateBinary(sqltypes.Blob, length) - case "mediumblob": - return MediumBlob, nil - case "longblob": - return LongBlob, nil - case "tinytext": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + case "tinytext", "tinyblob": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } return CreateString(sqltypes.Text, TinyTextBlobMax/collation.CharacterSet().MaxLength(), collation) - case "text": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + case "text", "blob": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } @@ -240,20 +266,20 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { return nil, err } return CreateString(sqltypes.Text, length, collation) - case "mediumtext", "long", "long varchar": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + case "mediumtext", "mediumblob", "long", "long varchar": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } return CreateString(sqltypes.Text, MediumTextBlobMax/collation.CharacterSet().MaxLength(), collation) - case "longtext": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + case "longtext", "longblob": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } return CreateString(sqltypes.Text, LongTextBlobMax/collation.CharacterSet().MaxLength(), collation) - case "char", "character": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + case "char", "character", "binary": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } @@ -277,7 +303,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { } return CreateString(sqltypes.Char, length, sql.Collation_utf8mb3_general_ci) case "varchar", "char varying", "character varying": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } @@ -305,17 +331,11 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { return nil, err } return CreateString(sqltypes.VarChar, length, sql.Collation_utf8mb3_general_ci) - case "binary": - length := int64(1) - if ct.Length != nil { - var err error - length, err = strconv.ParseInt(string(ct.Length.Val), 10, 64) - if err != nil { - return nil, err - } - } - return CreateString(sqltypes.Binary, length, sql.Collation_binary) case "varbinary": + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) + if err != nil { + return nil, err + } if ct.Length == nil { return nil, fmt.Errorf("VARBINARY requires a length") } @@ -327,7 +347,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { if length > varcharVarbinaryMax { return nil, ErrLengthTooLarge.New(length, varcharVarbinaryMax) } - return CreateString(sqltypes.VarBinary, length, sql.Collation_binary) + return CreateString(sqltypes.VarBinary, length, collation) case "year": return Year, nil case "date": @@ -379,7 +399,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { return CreateDatetimeType(sqltypes.Datetime, int(precision)) case "enum": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err } @@ -388,7 +408,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { } return CreateEnumType(ct.EnumValues, collation) case "set": - collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate) + collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate) if err != nil { return nil, err }