diff --git a/go/test/endtoend/vtgate/setstatement/sysvar_test.go b/go/test/endtoend/vtgate/setstatement/sysvar_test.go index afc0f7e7e88..51b17640885 100644 --- a/go/test/endtoend/vtgate/setstatement/sysvar_test.go +++ b/go/test/endtoend/vtgate/setstatement/sysvar_test.go @@ -62,6 +62,14 @@ func TestSetSysVar(t *testing.T) { name: "default_storage_engine", expr: "INNODB", expected: `[[VARCHAR("InnoDB")]]`, + }, { + name: "character_set_client", + expr: "utf8", + expected: `[[VARCHAR("utf8")]]`, + }, { + name: "character_set_client", // ignored so will keep the actual value + expr: "@charvar", + expected: `[[VARCHAR("utf8")]]`, }, { name: "sql_mode", expr: "''", diff --git a/go/vt/vtgate/executor_set_test.go b/go/vt/vtgate/executor_set_test.go index b77c12f6b37..ed4f72547f7 100644 --- a/go/vt/vtgate/executor_set_test.go +++ b/go/vt/vtgate/executor_set_test.go @@ -179,24 +179,6 @@ func TestExecutorSet(t *testing.T) { }, { in: "set autocommit = 1+1", err: "invalid syntax: 1 + 1", - }, { - in: "set character_set_results=null", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set character_set_results='binary'", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set character_set_results='utf8'", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set character_set_results='utf8mb4'", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set character_set_results='latin1'", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set character_set_results='abcd'", - err: "disallowed value for character_set_results: abcd", }, { in: "set foo = 1", err: "unsupported construct: set foo = 1", @@ -278,6 +260,13 @@ func TestExecutorSetOp(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("unique_checks", "int64"), "0"), sqltypes.MakeTestResult(sqltypes.MakeTestFields("net_write_timeout", "int64"), "600"), sqltypes.MakeTestResult(sqltypes.MakeTestFields("net_read_timeout", "int64"), "300"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_client", "varchar"), "utf8"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("character_set_results", "varchar")), }) testcases := []struct { @@ -307,6 +296,20 @@ func TestExecutorSetOp(t *testing.T) { in: "set net_write_timeout = 600", }, { in: "set net_read_timeout = 600", + }, { + in: "set character_set_client = utf8", + }, { + in: "set character_set_results=null", + }, { + in: "set character_set_results='binary'", + }, { + in: "set character_set_results='utf8'", + }, { + in: "set character_set_results=utf8mb4", + }, { + in: "set character_set_results='latin1'", + }, { + in: "set character_set_results='abcd'", }} for _, tcase := range testcases { t.Run(tcase.in, func(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/set.go b/go/vt/vtgate/planbuilder/set.go index 8800df00784..fe8f75cb665 100644 --- a/go/vt/vtgate/planbuilder/set.go +++ b/go/vt/vtgate/planbuilder/set.go @@ -160,6 +160,7 @@ var checkAndIgnore = []string{ "character_set_connection", "character_set_database", "character_set_filesystem", + "character_set_results", "character_set_server", "collation_connection", "collation_database", @@ -310,12 +311,9 @@ func buildNotSupported(e *sqlparser.SetExpr, _ ContextVSchema) (engine.SetOp, er } func buildSetOpIgnore(expr *sqlparser.SetExpr, _ ContextVSchema) (engine.SetOp, error) { - buf := sqlparser.NewTrackedBuffer(nil) - buf.Myprintf("%v", expr.Expr) - return &engine.SysVarIgnore{ Name: expr.Name.Lowered(), - Expr: buf.String(), + Expr: extractValue(expr), }, nil } @@ -329,7 +327,7 @@ func buildSetOpCheckAndIgnore(expr *sqlparser.SetExpr, vschema ContextVSchema) ( Name: expr.Name.Lowered(), Keyspace: keyspace, TargetDestination: dest, - Expr: sqlparser.String(expr.Expr), + Expr: extractValue(expr), }, nil } @@ -360,7 +358,7 @@ func buildSetOpVarSet(expr *sqlparser.SetExpr, vschema ContextVSchema) (engine.S Name: expr.Name.Lowered(), Keyspace: ks, TargetDestination: vschema.Destination(), - Expr: sqlparser.String(expr.Expr), + Expr: extractValue(expr), }, nil } @@ -377,6 +375,17 @@ func resolveDestination(vschema ContextVSchema) (*vindexes.Keyspace, key.Destina return keyspace, dest, nil } +func extractValue(expr *sqlparser.SetExpr) string { + value := sqlparser.String(expr.Expr) + switch colname := expr.Expr.(type) { + case *sqlparser.ColName: + if colname.Name.AtCount() == sqlparser.NoAt { + value = fmt.Sprintf("'%s'", value) + } + } + return value +} + // whitelist of functions knows to be safe to pass through to mysql for evaluation // this list tries to not include functions that might return different results on different tablets var validFuncs = map[string]interface{}{