Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions enginetest/queries/create_table_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,36 @@ var CreateTableQueries = []WriteQueryTest{
SelectQuery: `select * from t1 order by i`,
ExpectedSelect: []sql.Row{{"newfirst row", 1}, {"newsecond row", 2}, {"newthird row", 3}},
},
{
WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key collate binary)`,
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
SelectQuery: `SHOW CREATE TABLE t1`,
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key charset binary)`,
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
SelectQuery: `SHOW CREATE TABLE t1`,
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key character set binary)`,
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
SelectQuery: `SHOW CREATE TABLE t1`,
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key charset binary collate binary)`,
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
SelectQuery: `SHOW CREATE TABLE t1`,
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
WriteQuery: `CREATE TABLE t1 (pk varchar(10) primary key character set binary collate binary)`,
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
SelectQuery: `SHOW CREATE TABLE t1`,
ExpectedSelect: []sql.Row{sql.Row{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
}

var CreateTableScriptTests = []ScriptTest{
Expand Down Expand Up @@ -496,6 +526,38 @@ var CreateTableScriptTests = []ScriptTest{
},
},
},
{
Name: "valid character set and collation options",
SetUpScript: []string{
"create table parent (a int primary key)",
},
Assertions: []ScriptTestAssertion{
{
Query: `CREATE TABLE t1 (pk varbinary(10) primary key collate utf8mb4_0900_bin)`,
ExpectedErr: types.ErrBinaryCollation,
},
{
Query: `CREATE TABLE t1 (pk varbinary(10) primary key charset utf8mb4_0900_bin)`,
ExpectedErr: types.ErrCharacterSetOnInvalidType,
},
{
Query: `CREATE TABLE t1 (pk varbinary(10) primary key character set utf8mb4)`,
ExpectedErr: types.ErrCharacterSetOnInvalidType,
},
{
Query: `CREATE TABLE t1 (pk varbinary(10) primary key charset utf8mb4 collate utf8mb4_0900_bin)`,
ExpectedErr: types.ErrCharacterSetOnInvalidType,
},
{
Query: `CREATE TABLE t1 (pk varbinary(10) primary key character set utf8mb4 collate utf8mb4_0900_bin)`,
ExpectedErr: types.ErrCharacterSetOnInvalidType,
},
{
Query: `CREATE TABLE t1 (pk int primary key character set utf8mb4)`,
ExpectedErr: types.ErrCharacterSetOnInvalidType,
},
},
},
}

var CreateTableAutoIncrementTests = []ScriptTest{
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f
github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.1
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ github.com/dolthub/vitess v0.0.0-20231109003730-c0fa018b5ef6 h1:/GOBV8ceNCMuyS9/
github.com/dolthub/vitess v0.0.0-20231109003730-c0fa018b5ef6/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f h1:I480LKHhb4usnF3dYhp6J4ORKMrncNKaWYZvIZwlK+U=
github.com/dolthub/vitess v0.0.0-20231127171856-2466012fb61f/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674 h1:OYEf4PpMUG7rj51/l/Da9Iv9tAMnBvnEi+dA+tsoKGA=
github.com/dolthub/vitess v0.0.0-20231202001124-09287d7cc674/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
3 changes: 2 additions & 1 deletion sql/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,9 +1037,10 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t
generated := make([]ast.Expr, len(tableSpec.Columns))
var schema sql.Schema
for i, cd := range tableSpec.Columns {
sqlType := cd.Type.SQLType()
// Use the table's collation if no character or collation was specified for the table
if len(cd.Type.Charset) == 0 && len(cd.Type.Collate) == 0 {
if tableCollation != sql.Collation_Unspecified {
if tableCollation != sql.Collation_Unspecified && !types.IsBinary(sqlType) {
cd.Type.Collate = tableCollation.Name()
}
}
Expand Down
98 changes: 59 additions & 39 deletions sql/types/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"time"

"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/shopspring/decimal"
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
)
Expand Down Expand Up @@ -116,8 +118,47 @@ func ApproximateTypeFromValue(val interface{}) sql.Type {
}
}

// IsBinary returns whether the type represents binary data.
func IsBinary(sqlType query.Type) bool {
switch sqlType {
case sqltypes.Binary,
sqltypes.VarBinary,
sqltypes.Blob,
sqltypes.TypeJSON,
sqltypes.Geometry:
return true
}
return false
}

func allowsCharSet(sqlType query.Type) bool {
switch sqlType {
case sqltypes.VarChar,
sqltypes.Char,
sqltypes.Text,
sqltypes.Enum,
sqltypes.Set:
return true
}
return false
}

var ErrCharacterSetOnInvalidType = errors.NewKind("Only character columns, enums, and sets can have a CHARACTER SET option")

// ColumnTypeToType gets the column type using the column definition.
func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {

sqlType := ct.SQLType()

if !allowsCharSet(sqlType) && len(ct.Charset) != 0 {
return nil, ErrCharacterSetOnInvalidType.New()
}

collate := ct.Collate
if IsBinary(sqlType) && collate == "" {
collate = sql.Collation_binary.Name()
}

switch strings.ToLower(ct.Type) {
case "boolean", "bool":
return Boolean, nil
Expand Down Expand Up @@ -206,29 +247,14 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
}
}
return CreateBitType(uint8(length))
case "tinyblob":
return TinyBlob, nil
case "blob":
if ct.Length == nil {
return Blob, nil
}
length, err := strconv.ParseInt(string(ct.Length.Val), 10, 64)
if err != nil {
return nil, err
}
return CreateBinary(sqltypes.Blob, length)
case "mediumblob":
return MediumBlob, nil
case "longblob":
return LongBlob, nil
case "tinytext":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
case "tinytext", "tinyblob":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
return CreateString(sqltypes.Text, TinyTextBlobMax/collation.CharacterSet().MaxLength(), collation)
case "text":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
case "text", "blob":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
Expand All @@ -240,20 +266,20 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
return nil, err
}
return CreateString(sqltypes.Text, length, collation)
case "mediumtext", "long", "long varchar":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
case "mediumtext", "mediumblob", "long", "long varchar":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
return CreateString(sqltypes.Text, MediumTextBlobMax/collation.CharacterSet().MaxLength(), collation)
case "longtext":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
case "longtext", "longblob":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
return CreateString(sqltypes.Text, LongTextBlobMax/collation.CharacterSet().MaxLength(), collation)
case "char", "character":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
case "char", "character", "binary":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
Expand All @@ -277,7 +303,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
}
return CreateString(sqltypes.Char, length, sql.Collation_utf8mb3_general_ci)
case "varchar", "char varying", "character varying":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -305,17 +331,11 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
return nil, err
}
return CreateString(sqltypes.VarChar, length, sql.Collation_utf8mb3_general_ci)
case "binary":
length := int64(1)
if ct.Length != nil {
var err error
length, err = strconv.ParseInt(string(ct.Length.Val), 10, 64)
if err != nil {
return nil, err
}
}
return CreateString(sqltypes.Binary, length, sql.Collation_binary)
case "varbinary":
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
if ct.Length == nil {
return nil, fmt.Errorf("VARBINARY requires a length")
}
Expand All @@ -327,7 +347,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
if length > varcharVarbinaryMax {
return nil, ErrLengthTooLarge.New(length, varcharVarbinaryMax)
}
return CreateString(sqltypes.VarBinary, length, sql.Collation_binary)
return CreateString(sqltypes.VarBinary, length, collation)
case "year":
return Year, nil
case "date":
Expand Down Expand Up @@ -379,7 +399,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {

return CreateDatetimeType(sqltypes.Datetime, int(precision))
case "enum":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
Expand All @@ -388,7 +408,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) {
}
return CreateEnumType(ct.EnumValues, collation)
case "set":
collation, err := sql.ParseCollation(&ct.Charset, &ct.Collate, ct.BinaryCollate)
collation, err := sql.ParseCollation(&ct.Charset, &collate, ct.BinaryCollate)
if err != nil {
return nil, err
}
Expand Down