diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index c7a10b4d0e..697777052d 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -18,6 +18,7 @@ import ( "math" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" ) var VariableQueries = []ScriptTest{ @@ -60,6 +61,51 @@ var VariableQueries = []ScriptTest{ {uint64(2)}, }, }, + { + Name: "variable scope is included in returned column name when explicitly provided", + Assertions: []ScriptTestAssertion{ + { + Query: "select @@max_allowed_packet;", + Expected: []sql.Row{{1073741824}}, + ExpectedColumns: sql.Schema{ + { + Name: "@@max_allowed_packet", + Type: types.Uint64, + }, + }, + }, + { + Query: "select @@session.max_allowed_packet;", + Expected: []sql.Row{{1073741824}}, + ExpectedColumns: sql.Schema{ + { + Name: "@@session.max_allowed_packet", + Type: types.Uint64, + }, + }, + }, + { + Query: "select @@global.max_allowed_packet;", + Expected: []sql.Row{{1073741824}}, + ExpectedColumns: sql.Schema{ + { + Name: "@@global.max_allowed_packet", + Type: types.Uint64, + }, + }, + }, + { + Query: "select @@GLoBAL.max_allowed_packet;", + Expected: []sql.Row{{1073741824}}, + ExpectedColumns: sql.Schema{ + { + Name: "@@GLoBAL.max_allowed_packet", + Type: types.Uint64, + }, + }, + }, + }, + }, { Name: "@@server_id", Assertions: []ScriptTestAssertion{ diff --git a/go.mod b/go.mod index 39466ba566..23c9c75b9c 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929 + github.com/dolthub/vitess v0.0.0-20240117224045-c9088efc7f8c github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 430edd0a0a..9a9f524340 100644 --- a/go.sum +++ b/go.sum @@ -60,12 +60,12 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9X github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0 h1:P8wb4dR5krirPa0swEJbEObc/I7GaAM/01nOnuQrl0c= github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3 h1:nEwq2/8gTI2jm/4APIMTrWNDDRCn8AWJjrCbH+d7CJc= -github.com/dolthub/vitess v0.0.0-20240117061527-f9260279b3d3/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48 h1:Bdsy71WXx4yvK71IFwIqQ2duL5a/y15EuKEhVN51bSE= -github.com/dolthub/vitess v0.0.0-20240117195812-420942cccb48/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= -github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929 h1:6SExRtdwbcNPl7q09SXxtnwk+pVdhrsd0ap1DVfphEg= -github.com/dolthub/vitess v0.0.0-20240117220136-123ca09b8929/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240117190712-1b506b108f54 h1:AjTW9LRaRq12Pu1tt0YG+vjfDduA59O7DhtyiNgY0Yw= +github.com/dolthub/vitess v0.0.0-20240117190712-1b506b108f54/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240117191503-43038f0e7332 h1:lkY6/i/jFn70eQWfZaxMhjlMtcTZiLErCBAoKcySzCQ= +github.com/dolthub/vitess v0.0.0-20240117191503-43038f0e7332/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240117224045-c9088efc7f8c h1:smzyKT85EbK/RZnu+KFku63enuWx5IZ9rehPdngoBrk= +github.com/dolthub/vitess v0.0.0-20240117224045-c9088efc7f8c/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/server/handler.go b/server/handler.go index dd96f7599b..cab3df3ef2 100644 --- a/server/handler.go +++ b/server/handler.go @@ -718,6 +718,20 @@ func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field { charset = uint32(charSetResults) } + var flags querypb.MySqlFlag + if !c.Nullable { + flags = flags | querypb.MySqlFlag_NOT_NULL_FLAG + } + if c.AutoIncrement { + flags = flags | querypb.MySqlFlag_AUTO_INCREMENT_FLAG + } + if c.PrimaryKey { + flags = flags | querypb.MySqlFlag_PRI_KEY_FLAG + } + if types.IsUnsigned(c.Type) { + flags = flags | querypb.MySqlFlag_UNSIGNED_FLAG + } + fields[i] = &querypb.Field{ Name: c.Name, OrgName: c.Name, @@ -727,6 +741,7 @@ func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field { Type: c.Type.Type(), Charset: charset, ColumnLength: c.Type.MaxTextResponseByteLength(ctx), + Flags: uint32(flags), } if types.IsDecimal(c.Type) { diff --git a/server/handler_test.go b/server/handler_test.go index 498c7eaa1c..a09ad44258 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -257,7 +257,7 @@ func TestHandlerComPrepare(t *testing.T) { name: "select statement returns non-nil schema", statement: "select c1 from test where c1 > ?", expected: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, }, { @@ -324,7 +324,7 @@ func TestHandlerComPrepareExecute(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -402,7 +402,7 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -663,57 +663,57 @@ func TestSchemaToFields(t *testing.T) { expected := []*query.Field{ // Blob, Text, and JSON Types - {Name: "tinyblob", OrgName: "tinyblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 255}, - {Name: "blob", OrgName: "blob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 65_535}, - {Name: "mediumblob", OrgName: "mediumblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 16_777_215}, - {Name: "longblob", OrgName: "longblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, - {Name: "tinytext", OrgName: "tinytext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 1020}, - {Name: "text", OrgName: "text", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 262_140}, - {Name: "mediumtext", OrgName: "mediumtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 67_108_860}, - {Name: "longtext", OrgName: "longtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4_294_967_295}, - {Name: "json", OrgName: "json", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_JSON, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, + {Name: "tinyblob", OrgName: "tinyblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 255, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "blob", OrgName: "blob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 65_535, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "mediumblob", OrgName: "mediumblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 16_777_215, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "longblob", OrgName: "longblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "tinytext", OrgName: "tinytext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 1020, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "text", OrgName: "text", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 262_140, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "mediumtext", OrgName: "mediumtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 67_108_860, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "longtext", OrgName: "longtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "json", OrgName: "json", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_JSON, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Geometry Types - {Name: "geometry", OrgName: "geometry", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, - {Name: "point", OrgName: "point", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, - {Name: "polygon", OrgName: "polygon", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, - {Name: "linestring", OrgName: "linestring", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295}, + {Name: "geometry", OrgName: "geometry", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "point", OrgName: "point", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "polygon", OrgName: "polygon", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "linestring", OrgName: "linestring", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Integer Types - {Name: "uint8", OrgName: "uint8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 3}, - {Name: "int8", OrgName: "int8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4}, - {Name: "uint16", OrgName: "uint16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 5}, - {Name: "int16", OrgName: "int16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6}, - {Name: "uint24", OrgName: "uint24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 8}, - {Name: "int24", OrgName: "int24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 9}, - {Name: "uint32", OrgName: "uint32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10}, - {Name: "int32", OrgName: "int32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11}, - {Name: "uint64", OrgName: "uint64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20}, - {Name: "int64", OrgName: "int64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20}, + {Name: "uint8", OrgName: "uint8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 3, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "int8", OrgName: "int8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "uint16", OrgName: "uint16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 5, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "int16", OrgName: "int16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "uint24", OrgName: "uint24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 8, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "int24", OrgName: "int24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 9, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "uint32", OrgName: "uint32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "int32", OrgName: "int32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "uint64", OrgName: "uint64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)}, + {Name: "int64", OrgName: "int64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Floating Point and Decimal Types - {Name: "float32", OrgName: "float32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12}, - {Name: "float64", OrgName: "float64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 22}, - {Name: "decimal10_0", OrgName: "decimal10_0", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Decimals: 0}, - {Name: "decimal60_30", OrgName: "decimal60_30", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 62, Decimals: 30}, + {Name: "float32", OrgName: "float32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "float64", OrgName: "float64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 22, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "decimal10_0", OrgName: "decimal10_0", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Decimals: 0, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "decimal60_30", OrgName: "decimal60_30", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 62, Decimals: 30, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Char, Binary, and Bit Types - {Name: "varchar50", OrgName: "varchar50", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARCHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 50 * 4}, - {Name: "varbinary12345", OrgName: "varbinary12345", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARBINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 12345}, - {Name: "binary123", OrgName: "binary123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 123}, - {Name: "char123", OrgName: "char123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_CHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 123 * 4}, - {Name: "bit12", OrgName: "bit12", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BIT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12}, + {Name: "varchar50", OrgName: "varchar50", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARCHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 50 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "varbinary12345", OrgName: "varbinary12345", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARBINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 12345, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "binary123", OrgName: "binary123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 123, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "char123", OrgName: "char123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_CHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 123 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "bit12", OrgName: "bit12", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BIT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Dates - {Name: "datetime", OrgName: "datetime", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATETIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26}, - {Name: "timestamp", OrgName: "timestamp", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIMESTAMP, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26}, - {Name: "date", OrgName: "date", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATE, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10}, - {Name: "time", OrgName: "time", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 17}, - {Name: "year", OrgName: "year", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_YEAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4}, + {Name: "datetime", OrgName: "datetime", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATETIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "timestamp", OrgName: "timestamp", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIMESTAMP, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "date", OrgName: "date", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATE, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "time", OrgName: "time", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 17, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "year", OrgName: "year", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_YEAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, // Set and Enum Types - {Name: "set", OrgName: "set", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_SET, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 72}, - {Name: "enum", OrgName: "enum", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_ENUM, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20}, + {Name: "set", OrgName: "set", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_SET, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 72, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "enum", OrgName: "enum", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_ENUM, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, } require.Equal(len(schema), len(expected)) diff --git a/sql/expression/variables.go b/sql/expression/variables.go index a22fe310f9..d925062428 100644 --- a/sql/expression/variables.go +++ b/sql/expression/variables.go @@ -25,17 +25,22 @@ import ( // SystemVar is an expression that returns the value of a system variable. It's also used as the expression on the left // hand side of a SET statement for a system variable. type SystemVar struct { - Name string - Collation sql.CollationID - Scope sql.SystemVariableScope + Name string + Collation sql.CollationID + Scope sql.SystemVariableScope + SpecifiedScope string } var _ sql.Expression = (*SystemVar)(nil) var _ sql.CollationCoercible = (*SystemVar)(nil) -// NewSystemVar creates a new SystemVar expression. -func NewSystemVar(name string, scope sql.SystemVariableScope) *SystemVar { - return &SystemVar{name, sql.CollationID(0), scope} +// NewSystemVar creates a new SystemVar expression for the system variable named |name| with the specified |scope|. +// The |specifiedScope| parameter indicates the exact scope that was specified in the original reference to this +// system variable, and is used to ensure we output a column name in a result set that exactly matches how the +// system variable was originally referenced. If the |specifiedScope| parameter is empty, then the scope was not +// originally specified and any scope has been inferred. +func NewSystemVar(name string, scope sql.SystemVariableScope, specifiedScope string) *SystemVar { + return &SystemVar{name, sql.CollationID(0), scope, specifiedScope} } // Children implements the sql.Expression interface. @@ -100,13 +105,11 @@ func (v *SystemVar) Resolved() bool { return true } // String implements the sql.Expression interface. func (v *SystemVar) String() string { - switch v.Scope { - case sql.SystemVariableScope_Session: - return fmt.Sprintf("@@SESSION.%s", v.Name) - case sql.SystemVariableScope_Global: - return fmt.Sprintf("@@GLOBAL.%s", v.Name) - default: // should never happen - return fmt.Sprintf("@@UNKNOWN(%v).%s", v.Scope, v.Name) + // If the scope wasn't explicitly provided, then don't include it in the string representation + if v.SpecifiedScope == "" { + return fmt.Sprintf("@@%s", v.Name) + } else { + return fmt.Sprintf("@@%s.%s", v.SpecifiedScope, v.Name) } } diff --git a/sql/planbuilder/set.go b/sql/planbuilder/set.go index b1108ed657..e2abb532da 100644 --- a/sql/planbuilder/set.go +++ b/sql/planbuilder/set.go @@ -98,27 +98,27 @@ func (b *Builder) setExprsToExpressions(inScope *scope, e ast.SetVarExprs) []sql } switch strings.ToLower(expr.String()) { case "'isolation level repeatable read'": - varToSet := expression.NewSystemVar("transaction_isolation", scope) + varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral("REPEATABLE-READ", types.LongText)) continue case "'isolation level read committed'": - varToSet := expression.NewSystemVar("transaction_isolation", scope) + varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-COMMITTED", types.LongText)) continue case "'isolation level read uncommitted'": - varToSet := expression.NewSystemVar("transaction_isolation", scope) + varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-UNCOMMITTED", types.LongText)) continue case "'isolation level serializable'": - varToSet := expression.NewSystemVar("transaction_isolation", scope) + varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral("SERIALIZABLE", types.LongText)) continue case "'read write'": - varToSet := expression.NewSystemVar("transaction_read_only", scope) + varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral(false, types.Boolean)) continue case "'read only'": - varToSet := expression.NewSystemVar("transaction_read_only", scope) + varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope)) res[i] = expression.NewSetField(varToSet, expression.NewLiteral(true, types.Boolean)) continue } @@ -168,10 +168,12 @@ func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql var varName string var scope ast.SetScope var err error - if table != "" { - varName, scope, err = ast.VarScope(table, col) + var specifiedScope string + + if table == "" { + varName, scope, specifiedScope, err = ast.VarScope(col) } else { - varName, scope, err = ast.VarScope(col) + varName, scope, specifiedScope, err = ast.VarScope(table, col) } if err != nil { b.handleErr(err) @@ -187,11 +189,11 @@ func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql if !ok { return nil, scope, false } - return expression.NewSystemVar(varName, sql.SystemVariableScope_Global), scope, true + return expression.NewSystemVar(varName, sql.SystemVariableScope_Global, specifiedScope), scope, true case ast.SetScope_None, ast.SetScope_Session: switch strings.ToLower(varName) { case "character_set_database", "collation_database": - sysVar := expression.NewSystemVar(varName, sql.SystemVariableScope_Session) + sysVar := expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope) sysVar.Collation = sql.Collation_Default if db, err := b.cat.Database(b.ctx, b.ctx.GetCurrentDatabase()); err == nil { sysVar.Collation = plan.GetDatabaseCollation(b.ctx, db) @@ -202,7 +204,7 @@ func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql if err != nil { return nil, scope, false } - return expression.NewSystemVar(varName, sql.SystemVariableScope_Session), scope, true + return expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope), scope, true } case ast.SetScope_User: t, _, err := b.ctx.GetUserVariable(b.ctx, varName) @@ -214,9 +216,9 @@ func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql } return expression.NewUserVar(varName), scope, true case ast.SetScope_Persist: - return expression.NewSystemVar(varName, sql.SystemVariableScope_Persist), scope, true + return expression.NewSystemVar(varName, sql.SystemVariableScope_Persist, specifiedScope), scope, true case ast.SetScope_PersistOnly: - return expression.NewSystemVar(varName, sql.SystemVariableScope_PersistOnly), scope, true + return expression.NewSystemVar(varName, sql.SystemVariableScope_PersistOnly, specifiedScope), scope, true default: // shouldn't happen err := fmt.Errorf("unknown set scope %v", scope) b.handleErr(err) @@ -325,9 +327,9 @@ func (b *Builder) simplifySetExpr(name *ast.ColName, varScope ast.SetScope, val table := name.Qualifier.String() col := name.Name.Lowered() if table != "" { - varName, _, err = ast.VarScope(table, col) + varName, _, _, err = ast.VarScope(table, col) } else { - varName, _, err = ast.VarScope(col) + varName, _, _, err = ast.VarScope(col) } if err != nil { b.handleErr(err) diff --git a/sql/rowexec/set_test.go b/sql/rowexec/set_test.go index 33c5e34b92..1581a9246f 100644 --- a/sql/rowexec/set_test.go +++ b/sql/rowexec/set_test.go @@ -93,7 +93,7 @@ func TestPersistedSessionSetIterator(t *testing.T) { sqlCtx, globals := newPersistedSqlContext() s := plan.NewSet( []sql.Expression{ - expression.NewSetField(expression.NewSystemVar(test.name, test.scope), expression.NewLiteral(int64(test.value), types.Int64)), + expression.NewSetField(expression.NewSystemVar(test.name, test.scope, string(test.scope)), expression.NewLiteral(int64(test.value), types.Int64)), }, ) diff --git a/sql/types/system_bool.go b/sql/types/system_bool.go index 4b3c576467..e5b8c797ac 100644 --- a/sql/types/system_bool.go +++ b/sql/types/system_bool.go @@ -140,9 +140,8 @@ func (t SystemBoolType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t SystemBoolType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t SystemBoolType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/system_double.go b/sql/types/system_double.go index 6d7b1b8aea..325c5d42e9 100644 --- a/sql/types/system_double.go +++ b/sql/types/system_double.go @@ -125,9 +125,8 @@ func (t systemDoubleType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t systemDoubleType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t systemDoubleType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/system_enum.go b/sql/types/system_enum.go index e9fce2cc59..36d45bbf38 100644 --- a/sql/types/system_enum.go +++ b/sql/types/system_enum.go @@ -145,9 +145,8 @@ func (t systemEnumType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t systemEnumType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t systemEnumType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/system_int.go b/sql/types/system_int.go index 0782f76a2b..fb9ae4f33b 100644 --- a/sql/types/system_int.go +++ b/sql/types/system_int.go @@ -140,9 +140,8 @@ func (t systemIntType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t systemIntType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t systemIntType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/system_set.go b/sql/types/system_set.go index fc7f10228c..f6a3b24c68 100644 --- a/sql/types/system_set.go +++ b/sql/types/system_set.go @@ -127,6 +127,11 @@ func (t systemSetType) Equals(otherType sql.Type) bool { return false } +// MaxTextResponseByteLength implements the Type interface +func (t systemSetType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) +} + // Promote implements the Type interface. func (t systemSetType) Promote() sql.Type { return t diff --git a/sql/types/system_string.go b/sql/types/system_string.go index 077a5d6bf2..640df144d1 100644 --- a/sql/types/system_string.go +++ b/sql/types/system_string.go @@ -90,9 +90,8 @@ func (t systemStringType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t systemStringType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t systemStringType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/system_uint.go b/sql/types/system_uint.go index 5ffa53429c..81380da4c7 100644 --- a/sql/types/system_uint.go +++ b/sql/types/system_uint.go @@ -129,9 +129,8 @@ func (t systemUintType) Equals(otherType sql.Type) bool { } // MaxTextResponseByteLength implements the Type interface -func (t systemUintType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - // system types are not sent directly across the wire - return 0 +func (t systemUintType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return t.UnderlyingType().MaxTextResponseByteLength(ctx) } // Promote implements the Type interface. diff --git a/sql/types/typecheck.go b/sql/types/typecheck.go index 0bac63fa8d..298d5d4d2d 100644 --- a/sql/types/typecheck.go +++ b/sql/types/typecheck.go @@ -186,5 +186,9 @@ func IsTuple(t sql.Type) bool { // IsUnsigned checks if t is an unsigned type. func IsUnsigned(t sql.Type) bool { + if svt, ok := t.(sql.SystemVariableType); ok { + t = svt.UnderlyingType() + } + return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64 } diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index a2897449fa..0ee993a282 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -1372,7 +1372,7 @@ var systemVars = map[string]sql.SystemVariable{ Scope: sql.SystemVariableScope_Both, Dynamic: true, SetVarHintApplies: false, - Type: types.NewSystemIntType("max_allowed_packet", 1024, 1073741824, false), + Type: types.NewSystemUintType("max_allowed_packet", 1024, 1073741824), Default: int64(1073741824), }, "max_connect_errors": {