diff --git a/go/sqltypes/bind_variables.go b/go/sqltypes/bind_variables.go index 1d79f033681..3c9bc6bbe9e 100644 --- a/go/sqltypes/bind_variables.go +++ b/go/sqltypes/bind_variables.go @@ -75,6 +75,14 @@ func Int32BindVariable(v int32) *querypb.BindVariable { return ValueBindVariable(NewInt32(v)) } +// BoolBindVariable converts an bool to a int32 bind var. +func BoolBindVariable(v bool) *querypb.BindVariable { + if v { + return Int32BindVariable(1) + } + return Int32BindVariable(0) +} + // Int64BindVariable converts an int64 to a bind var. func Int64BindVariable(v int64) *querypb.BindVariable { return ValueBindVariable(NewInt64(v)) diff --git a/go/vt/sqlparser/bind_var_needs.go b/go/vt/sqlparser/bind_var_needs.go new file mode 100644 index 00000000000..cc086cf4f51 --- /dev/null +++ b/go/vt/sqlparser/bind_var_needs.go @@ -0,0 +1,72 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +// BindVarNeeds represents the bind vars that need to be provided as the result of expression rewriting. +type BindVarNeeds struct { + NeedFunctionResult, + NeedSystemVariable, + // NeedUserDefinedVariables keeps track of all user defined variables a query is using + NeedUserDefinedVariables []string +} + +//MergeWith adds bind vars needs coming from sub scopes +func (bvn *BindVarNeeds) MergeWith(other *BindVarNeeds) { + bvn.NeedFunctionResult = append(bvn.NeedFunctionResult, other.NeedFunctionResult...) + bvn.NeedSystemVariable = append(bvn.NeedSystemVariable, other.NeedSystemVariable...) + bvn.NeedUserDefinedVariables = append(bvn.NeedUserDefinedVariables, other.NeedUserDefinedVariables...) +} + +//AddFuncResult adds a function bindvar need +func (bvn *BindVarNeeds) AddFuncResult(name string) { + bvn.NeedFunctionResult = append(bvn.NeedFunctionResult, name) +} + +//AddSysVar adds a system variable bindvar need +func (bvn *BindVarNeeds) AddSysVar(name string) { + bvn.NeedSystemVariable = append(bvn.NeedSystemVariable, name) +} + +//AddUserDefVar adds a user defined variable bindvar need +func (bvn *BindVarNeeds) AddUserDefVar(name string) { + bvn.NeedUserDefinedVariables = append(bvn.NeedUserDefinedVariables, name) +} + +//NeedsFuncResult says if a function result needs to be provided +func (bvn *BindVarNeeds) NeedsFuncResult(name string) bool { + return contains(bvn.NeedFunctionResult, name) +} + +//NeedsSysVar says if a function result needs to be provided +func (bvn *BindVarNeeds) NeedsSysVar(name string) bool { + return contains(bvn.NeedSystemVariable, name) +} + +func (bvn *BindVarNeeds) HasRewrites() bool { + return len(bvn.NeedFunctionResult) > 0 || + len(bvn.NeedUserDefinedVariables) > 0 || + len(bvn.NeedSystemVariable) > 0 +} + +func contains(strings []string, name string) bool { + for _, s := range strings { + if name == s { + return true + } + } + return false +} diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index 1ea4bae0db4..9542d9a325b 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -19,7 +19,8 @@ package sqlparser import ( "strings" - "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/sysvars" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -33,15 +34,6 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix return RewriteAST(in) } -// BindVarNeeds represents the bind vars that need to be provided as the result of expression rewriting. -type BindVarNeeds struct { - NeedLastInsertID bool - NeedDatabase bool - NeedFoundRows bool - NeedRowCount bool - NeedUserDefinedVariables []string -} - // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries func RewriteAST(in Statement) (*RewriteASTResult, error) { er := newExpressionRewriter() @@ -56,21 +48,8 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { } r := &RewriteASTResult{ - AST: out, - } - for k := range er.bindVars { - switch k { - case LastInsertIDName: - r.NeedLastInsertID = true - case DBVarName: - r.NeedDatabase = true - case FoundRowsName: - r.NeedFoundRows = true - case RowCountName: - r.NeedRowCount = true - default: - r.NeedUserDefinedVariables = append(r.NeedUserDefinedVariables, k) - } + AST: out, + BindVarNeeds: er.bindVars, } return r, nil } @@ -96,18 +75,18 @@ func shouldRewriteDatabaseFunc(in Statement) bool { // RewriteASTResult contains the rewritten ast and meta information about it type RewriteASTResult struct { - BindVarNeeds + *BindVarNeeds AST Statement // The rewritten AST } type expressionRewriter struct { - bindVars map[string]struct{} + bindVars *BindVarNeeds shouldRewriteDatabaseFunc bool err error } func newExpressionRewriter() *expressionRewriter { - return &expressionRewriter{bindVars: make(map[string]struct{})} + return &expressionRewriter{bindVars: &BindVarNeeds{}} } const ( @@ -127,6 +106,18 @@ const ( UserDefinedVariableName = "__vtudv" ) +func (er *expressionRewriter) rewriteAliasedExpr(cursor *Cursor, node *AliasedExpr) (*BindVarNeeds, error) { + inner := newExpressionRewriter() + inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc + tmp := Rewrite(node.Expr, inner.goingDown, nil) + newExpr, ok := tmp.(Expr) + if !ok { + return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) + } + node.Expr = newExpr + return inner.bindVars, nil +} + func (er *expressionRewriter) goingDown(cursor *Cursor) bool { switch node := cursor.Node().(type) { // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` @@ -136,35 +127,50 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { if ok && aliasedExpr.As.IsEmpty() { buf := NewTrackedBuffer(nil) aliasedExpr.Expr.Format(buf) - inner := newExpressionRewriter() - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(aliasedExpr.Expr, inner.goingDown, nil) - newExpr, ok := tmp.(Expr) - if !ok { - log.Errorf("failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) + innerBindVarNeeds, err := er.rewriteAliasedExpr(cursor, aliasedExpr) + if err != nil { + er.err = err return false } - aliasedExpr.Expr = newExpr - if inner.didAnythingChange() { + if innerBindVarNeeds.HasRewrites() { aliasedExpr.As = NewColIdent(buf.String()) } - for k := range inner.bindVars { - er.needBindVarFor(k) - } + er.bindVars.MergeWith(innerBindVarNeeds) } } case *FuncExpr: er.funcRewrite(cursor, node) case *ColName: - if node.Name.at == SingleAt { - udv := strings.ToLower(node.Name.CompliantName()) - cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) - er.needBindVarFor(udv) + switch node.Name.at { + case SingleAt: + er.udvRewrite(cursor, node) + case DoubleAt: + er.sysVarRewrite(cursor, node) } } return true } +func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { + lowered := node.Name.Lowered() + switch lowered { + case sysvars.Autocommit.Name, + sysvars.ClientFoundRows.Name, + sysvars.SkipQueryPlanCache.Name, + sysvars.SQLSelectLimit.Name, + sysvars.TransactionMode.Name, + sysvars.Workload.Name: + cursor.Replace(bindVarExpression("__vt" + lowered)) + er.bindVars.AddSysVar(lowered) + } +} + +func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) + er.bindVars.AddUserDefVar(udv) +} + func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { switch { // last_insert_id() -> :__lastInsertId @@ -173,7 +179,7 @@ func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { er.err = vterrors.New(vtrpc.Code_UNIMPLEMENTED, "Argument to LAST_INSERT_ID() not supported") } else { cursor.Replace(bindVarExpression(LastInsertIDName)) - er.needBindVarFor(LastInsertIDName) + er.bindVars.AddFuncResult(LastInsertIDName) } // database() -> :__vtdbname case er.shouldRewriteDatabaseFunc && @@ -183,7 +189,7 @@ func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { er.err = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. %s() takes no arguments", node.Name.String()) } else { cursor.Replace(bindVarExpression(DBVarName)) - er.needBindVarFor(DBVarName) + er.bindVars.AddFuncResult(DBVarName) } // found_rows() -> :__vtfrows case node.Name.EqualString("found_rows"): @@ -191,7 +197,7 @@ func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to FOUND_ROWS() not supported") } else { cursor.Replace(bindVarExpression(FoundRowsName)) - er.needBindVarFor(FoundRowsName) + er.bindVars.AddFuncResult(FoundRowsName) } // row_count() -> :__vtrcount case node.Name.EqualString("row_count"): @@ -199,22 +205,11 @@ func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Arguments to ROW_COUNT() not supported") } else { cursor.Replace(bindVarExpression(RowCountName)) - er.needBindVarFor(RowCountName) + er.bindVars.AddFuncResult(RowCountName) } } } -// instead of creating new objects, we'll reuse this one -var token = struct{}{} - -func (er *expressionRewriter) needBindVarFor(name string) { - er.bindVars[name] = token -} - -func (er *expressionRewriter) didAnythingChange() bool { - return len(er.bindVars) > 0 -} - func bindVarExpression(name string) Expr { return NewArgument([]byte(":" + name)) } diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 110c4883373..a627620026a 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -19,127 +19,145 @@ package sqlparser import ( "testing" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/vt/sysvars" + "github.com/stretchr/testify/require" ) type myTestCase struct { - in, expected string - liid, db, foundRows, rowCount bool - udv int + in, expected string + liid, db, foundRows, rowCount bool + udv int + autocommit, clientFoundRows, skipQueryPlanCache bool + sqlSelectLimit, transactionMode, workload bool } func TestRewrites(in *testing.T) { - tests := []myTestCase{ - { - in: "SELECT 42", - expected: "SELECT 42", - // no bindvar needs - }, - { - in: "SELECT last_insert_id()", - expected: "SELECT :__lastInsertId as `last_insert_id()`", - liid: true, - }, - { - in: "SELECT database()", - expected: "SELECT :__vtdbname as `database()`", - db: true, - }, - { - in: "SELECT database() from test", - expected: "SELECT database() from test", - // no bindvar needs - }, - { - in: "SELECT last_insert_id() as test", - expected: "SELECT :__lastInsertId as test", - liid: true, - }, - { - in: "SELECT last_insert_id() + database()", - expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", - db: true, liid: true, - }, - { - in: "select (select database()) from test", - expected: "select (select database() from dual) from test", - // no bindvar needs - }, - { - in: "select (select database() from dual) from test", - expected: "select (select database() from dual) from test", - // no bindvar needs - }, - { - in: "select (select database() from dual) from dual", - expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual", - db: true, - }, - { - in: "select id from user where database()", - expected: "select id from user where database()", - // no bindvar needs - }, - { - in: "select table_name from information_schema.tables where table_schema = database()", - expected: "select table_name from information_schema.tables where table_schema = database()", - // no bindvar needs - }, - { - in: "select schema()", - expected: "select :__vtdbname as `schema()`", - db: true, - }, - { - in: "select found_rows()", - expected: "select :__vtfrows as `found_rows()`", - foundRows: true, - }, - { - in: "select @`x y`", - expected: "select :__vtudvx_y as `@``x y``` from dual", - udv: 1, - }, - { - in: "select id from t where id = @x and val = @y", - expected: "select id from t where id = :__vtudvx and val = :__vtudvy", - db: false, udv: 2, - }, - { - in: "insert into t(id) values(@xyx)", - expected: "insert into t(id) values(:__vtudvxyx)", - db: false, udv: 1, - }, - { - in: "select row_count()", - expected: "select :__vtrcount as `row_count()`", - rowCount: true, - }, - { - in: "SELECT lower(database())", - expected: "SELECT lower(:__vtdbname) as `lower(database())`", - db: true, - }, - } + tests := []myTestCase{{ + in: "SELECT 42", + expected: "SELECT 42", + // no bindvar needs + }, { + in: "SELECT last_insert_id()", + expected: "SELECT :__lastInsertId as `last_insert_id()`", + liid: true, + }, { + in: "SELECT database()", + expected: "SELECT :__vtdbname as `database()`", + db: true, + }, { + in: "SELECT database() from test", + expected: "SELECT database() from test", + // no bindvar needs + }, { + in: "SELECT last_insert_id() as test", + expected: "SELECT :__lastInsertId as test", + liid: true, + }, { + in: "SELECT last_insert_id() + database()", + expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", + db: true, liid: true, + }, { + in: "select (select database()) from test", + expected: "select (select database() from dual) from test", + // no bindvar needs + }, { + in: "select (select database() from dual) from test", + expected: "select (select database() from dual) from test", + // no bindvar needs + }, { + in: "select (select database() from dual) from dual", + expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual", + db: true, + }, { + in: "select id from user where database()", + expected: "select id from user where database()", + // no bindvar needs + }, { + in: "select table_name from information_schema.tables where table_schema = database()", + expected: "select table_name from information_schema.tables where table_schema = database()", + // no bindvar needs + }, { + in: "select schema()", + expected: "select :__vtdbname as `schema()`", + db: true, + }, { + in: "select found_rows()", + expected: "select :__vtfrows as `found_rows()`", + foundRows: true, + }, { + in: "select @`x y`", + expected: "select :__vtudvx_y as `@``x y``` from dual", + udv: 1, + }, { + in: "select id from t where id = @x and val = @y", + expected: "select id from t where id = :__vtudvx and val = :__vtudvy", + db: false, udv: 2, + }, { + in: "insert into t(id) values(@xyx)", + expected: "insert into t(id) values(:__vtudvxyx)", + db: false, udv: 1, + }, { + in: "select row_count()", + expected: "select :__vtrcount as `row_count()`", + rowCount: true, + }, { + in: "SELECT lower(database())", + expected: "SELECT lower(:__vtdbname) as `lower(database())`", + db: true, + }, { + in: "SELECT @@autocommit", + expected: "SELECT :__vtautocommit as `@@autocommit`", + autocommit: true, + }, { + in: "SELECT @@client_found_rows", + expected: "SELECT :__vtclient_found_rows as `@@client_found_rows`", + clientFoundRows: true, + }, { + in: "SELECT @@skip_query_plan_cache", + expected: "SELECT :__vtskip_query_plan_cache as `@@skip_query_plan_cache`", + skipQueryPlanCache: true, + }, { + in: "SELECT @@sql_select_limit", + expected: "SELECT :__vtsql_select_limit as `@@sql_select_limit`", + sqlSelectLimit: true, + }, { + in: "SELECT @@transaction_mode", + expected: "SELECT :__vttransaction_mode as `@@transaction_mode`", + transactionMode: true, + }, { + in: "SELECT @@workload", + expected: "SELECT :__vtworkload as `@@workload`", + workload: true, + }} for _, tc := range tests { in.Run(tc.in, func(t *testing.T) { + require := require.New(t) stmt, err := Parse(tc.in) - require.NoError(t, err) + require.NoError(err) result, err := RewriteAST(stmt) - require.NoError(t, err) + require.NoError(err) expected, err := Parse(tc.expected) - require.NoError(t, err, "test expectation does not parse [%s]", tc.expected) + require.NoError(err, "test expectation does not parse [%s]", tc.expected) s := String(expected) - require.Equal(t, s, String(result.AST)) - require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id") - require.Equal(t, tc.db, result.NeedDatabase, "should need database name") - require.Equal(t, tc.foundRows, result.NeedFoundRows, "should need found rows") - require.Equal(t, tc.rowCount, result.NeedRowCount, "should need row count") - require.Equal(t, tc.udv, len(result.NeedUserDefinedVariables), "should need row count") + assert := assert.New(t) + assert.Equal(s, String(result.AST)) + assert.Equal(tc.liid, result.NeedsFuncResult(LastInsertIDName), "should need last insert id") + assert.Equal(tc.db, result.NeedsFuncResult(DBVarName), "should need database name") + assert.Equal(tc.foundRows, result.NeedsFuncResult(FoundRowsName), "should need found rows") + assert.Equal(tc.rowCount, result.NeedsFuncResult(RowCountName), "should need row count") + assert.Equal(tc.udv, len(result.NeedUserDefinedVariables), "count of user defined variables") + assert.Equal(tc.autocommit, result.NeedsSysVar(sysvars.Autocommit.Name), "should need :__vtautocommit") + assert.Equal(tc.clientFoundRows, result.NeedsSysVar(sysvars.ClientFoundRows.Name), "should need :__vtclientFoundRows") + assert.Equal(tc.skipQueryPlanCache, result.NeedsSysVar(sysvars.SkipQueryPlanCache.Name), "should need :__vtskipQueryPlanCache") + assert.Equal(tc.sqlSelectLimit, result.NeedsSysVar(sysvars.SQLSelectLimit.Name), "should need :__vtsqlSelectLimit") + assert.Equal(tc.transactionMode, result.NeedsSysVar(sysvars.TransactionMode.Name), "should need :__vttransactionMode") + assert.Equal(tc.workload, result.NeedsSysVar(sysvars.Workload.Name), "should need :__vtworkload") }) } } diff --git a/go/vt/sysvars/sysvars.go b/go/vt/sysvars/sysvars.go new file mode 100644 index 00000000000..cbd8984d9c3 --- /dev/null +++ b/go/vt/sysvars/sysvars.go @@ -0,0 +1,222 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sysvars + +// This information lives here, because it's needed from the vtgate planbuilder, the vtgate engine, +// and the AST rewriter, that happens to live in sqlparser. + +// SystemVariable is a system variable that Vitess handles in queries such as: +// select @@sql_mode +// set skip_query_plan_cache = true +type SystemVariable struct { + // IsBoolean is used to signal necessary type coercion so that strings + // and numbers can be evaluated to a boolean value + IsBoolean bool + + // IdentifierAsString allows identifiers (a.k.a. ColName) from the AST to be handled as if they are strings. + // SET transaction_mode = two_pc => SET transaction_mode = 'two_pc' + IdentifierAsString bool + + // Default is the default value, if none is given + Default string + + Name string +} + +var ( + on = "1" + off = "0" + utf8 = "'utf8'" + + Autocommit = SystemVariable{Name: "autocommit", IsBoolean: true, Default: on} + ClientFoundRows = SystemVariable{Name: "client_found_rows", IsBoolean: true, Default: off} + SkipQueryPlanCache = SystemVariable{Name: "skip_query_plan_cache", IsBoolean: true, Default: off} + TxReadOnly = SystemVariable{Name: "tx_read_only", IsBoolean: true, Default: off} + TransactionReadOnly = SystemVariable{Name: "transaction_read_only", IsBoolean: true, Default: off} + SQLSelectLimit = SystemVariable{Name: "sql_select_limit", Default: off} + TransactionMode = SystemVariable{Name: "transaction_mode", IdentifierAsString: true} + Workload = SystemVariable{Name: "workload", IdentifierAsString: true} + Charset = SystemVariable{Name: "charset", Default: utf8, IdentifierAsString: true} + Names = SystemVariable{Name: "names", Default: utf8, IdentifierAsString: true} + + VitessAware = []SystemVariable{ + Autocommit, + ClientFoundRows, + SkipQueryPlanCache, + TxReadOnly, + TransactionReadOnly, + SQLSelectLimit, + TransactionMode, + Workload, + Charset, + Names, + } + + IgnoreThese = []SystemVariable{ + {Name: "big_tables", IsBoolean: true}, + {Name: "bulk_insert_buffer_size"}, + {Name: "debug"}, + {Name: "default_storage_engine"}, + {Name: "default_tmp_storage_engine"}, + {Name: "innodb_strict_mode", IsBoolean: true}, + {Name: "innodb_support_xa", IsBoolean: true}, + {Name: "innodb_table_locks", IsBoolean: true}, + {Name: "innodb_tmpdir"}, + {Name: "join_buffer_size"}, + {Name: "keep_files_on_create", IsBoolean: true}, + {Name: "lc_messages"}, + {Name: "long_query_time"}, + {Name: "low_priority_updates", IsBoolean: true}, + {Name: "max_delayed_threads"}, + {Name: "max_insert_delayed_threads"}, + {Name: "multi_range_count"}, + {Name: "net_buffer_length"}, + {Name: "new", IsBoolean: true}, + {Name: "query_cache_type"}, + {Name: "query_cache_wlock_invalidate", IsBoolean: true}, + {Name: "query_prealloc_size"}, + {Name: "sql_buffer_result", IsBoolean: true}, + {Name: "transaction_alloc_block_size"}, + {Name: "wait_timeout"}, + } + + NotSupported = []SystemVariable{ + {Name: "audit_log_read_buffer_size"}, + {Name: "auto_increment_increment"}, + {Name: "auto_increment_offset"}, + {Name: "binlog_direct_non_transactional_updates"}, + {Name: "binlog_row_image"}, + {Name: "binlog_rows_query_log_events"}, + {Name: "innodb_ft_enable_stopword"}, + {Name: "innodb_ft_user_stopword_table"}, + {Name: "max_points_in_geometry"}, + {Name: "max_sp_recursion_depth"}, + {Name: "myisam_repair_threads"}, + {Name: "myisam_sort_buffer_size"}, + {Name: "myisam_stats_method"}, + {Name: "ndb_allow_copying_alter_table"}, + {Name: "ndb_autoincrement_prefetch_sz"}, + {Name: "ndb_blob_read_batch_bytes"}, + {Name: "ndb_blob_write_batch_bytes"}, + {Name: "ndb_deferred_constraints"}, + {Name: "ndb_force_send"}, + {Name: "ndb_fully_replicated"}, + {Name: "ndb_index_stat_enable"}, + {Name: "ndb_index_stat_option"}, + {Name: "ndb_join_pushdown"}, + {Name: "ndb_log_bin"}, + {Name: "ndb_log_exclusive_reads"}, + {Name: "ndb_row_checksum"}, + {Name: "ndb_use_exact_count"}, + {Name: "ndb_use_transactions"}, + {Name: "ndbinfo_max_bytes"}, + {Name: "ndbinfo_max_rows"}, + {Name: "ndbinfo_show_hidden"}, + {Name: "ndbinfo_table_prefix"}, + {Name: "old_alter_table"}, + {Name: "preload_buffer_size"}, + {Name: "rbr_exec_mode"}, + {Name: "sql_log_off"}, + {Name: "thread_pool_high_priority_connection"}, + {Name: "thread_pool_prio_kickup_timer"}, + {Name: "transaction_write_set_extraction"}, + } + UseReservedConn = []SystemVariable{ + {Name: "default_week_format"}, + {Name: "end_markers_in_json", IsBoolean: true}, + {Name: "eq_range_index_dive_limit"}, + {Name: "explicit_defaults_for_timestamp"}, + {Name: "foreign_key_checks", IsBoolean: true}, + {Name: "group_concat_max_len"}, + {Name: "max_heap_table_size"}, + {Name: "max_seeks_for_key"}, + {Name: "max_tmp_tables"}, + {Name: "min_examined_row_limit"}, + {Name: "old_passwords"}, + {Name: "optimizer_prune_level"}, + {Name: "optimizer_search_depth"}, + {Name: "optimizer_switch"}, + {Name: "optimizer_trace"}, + {Name: "optimizer_trace_features"}, + {Name: "optimizer_trace_limit"}, + {Name: "optimizer_trace_max_mem_size"}, + {Name: "transaction_isolation"}, + {Name: "tx_isolation"}, + {Name: "optimizer_trace_offset"}, + {Name: "parser_max_mem_size"}, + {Name: "profiling", IsBoolean: true}, + {Name: "profiling_history_size"}, + {Name: "query_alloc_block_size"}, + {Name: "range_alloc_block_size"}, + {Name: "range_optimizer_max_mem_size"}, + {Name: "read_buffer_size"}, + {Name: "read_rnd_buffer_size"}, + {Name: "show_create_table_verbosity", IsBoolean: true}, + {Name: "show_old_temporals", IsBoolean: true}, + {Name: "sort_buffer_size"}, + {Name: "sql_big_selects", IsBoolean: true}, + {Name: "sql_mode"}, + {Name: "sql_notes", IsBoolean: true}, + {Name: "sql_quote_show_create", IsBoolean: true}, + {Name: "sql_safe_updates", IsBoolean: true}, + {Name: "sql_warnings", IsBoolean: true}, + {Name: "tmp_table_size"}, + {Name: "transaction_prealloc_size"}, + {Name: "unique_checks", IsBoolean: true}, + {Name: "updatable_views_with_limit", IsBoolean: true}, + } + CheckAndIgnore = []SystemVariable{ + // TODO: Most of these settings should be moved into SysSetOpAware, and change Vitess behaviour. + // Until then, SET statements against these settings are allowed + // as long as they have the same value as the underlying database + {Name: "binlog_format"}, + {Name: "block_encryption_mode"}, + {Name: "character_set_client"}, + {Name: "character_set_connection"}, + {Name: "character_set_database"}, + {Name: "character_set_filesystem"}, + {Name: "character_set_results"}, + {Name: "character_set_server"}, + {Name: "collation_connection"}, + {Name: "collation_database"}, + {Name: "collation_server"}, + {Name: "completion_type"}, + {Name: "div_precision_increment"}, + {Name: "innodb_lock_wait_timeout"}, + {Name: "interactive_timeout"}, + {Name: "lc_time_names"}, + {Name: "lock_wait_timeout"}, + {Name: "max_allowed_packet"}, + {Name: "max_error_count"}, + {Name: "max_execution_time"}, + {Name: "max_join_size"}, + {Name: "max_length_for_sort_data"}, + {Name: "max_sort_length"}, + {Name: "max_user_connections"}, + {Name: "net_read_timeout"}, + {Name: "net_retry_count"}, + {Name: "net_write_timeout"}, + {Name: "session_track_gtids"}, + {Name: "session_track_schema", IsBoolean: true}, + {Name: "session_track_state_change", IsBoolean: true}, + {Name: "session_track_system_variables"}, + {Name: "session_track_transaction_info"}, + {Name: "sql_auto_is_null", IsBoolean: true}, + {Name: "time_zone"}, + {Name: "version_tokens_session"}, + } +) diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 14704210959..84353584a4c 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -100,15 +100,15 @@ func (t noopVCursor) SetAutocommit(bool) error { panic("implement me") } -func (t noopVCursor) SetClientFoundRows(bool) { +func (t noopVCursor) SetClientFoundRows(bool) error { panic("implement me") } -func (t noopVCursor) SetSkipQueryPlanCache(bool) { +func (t noopVCursor) SetSkipQueryPlanCache(bool) error { panic("implement me") } -func (t noopVCursor) SetSQLSelectLimit(int64) { +func (t noopVCursor) SetSQLSelectLimit(int64) error { panic("implement me") } @@ -405,15 +405,15 @@ func (f *loggingVCursor) SetAutocommit(bool) error { panic("implement me") } -func (f *loggingVCursor) SetClientFoundRows(bool) { +func (f *loggingVCursor) SetClientFoundRows(bool) error { panic("implement me") } -func (f *loggingVCursor) SetSkipQueryPlanCache(bool) { +func (f *loggingVCursor) SetSkipQueryPlanCache(bool) error { panic("implement me") } -func (f *loggingVCursor) SetSQLSelectLimit(int64) { +func (f *loggingVCursor) SetSQLSelectLimit(int64) error { panic("implement me") } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index e0599812625..7cb39769e91 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -113,9 +113,9 @@ type ( ShardSession() []*srvtopo.ResolvedShard SetAutocommit(bool) error - SetClientFoundRows(bool) - SetSkipQueryPlanCache(bool) - SetSQLSelectLimit(int64) + SetClientFoundRows(bool) error + SetSkipQueryPlanCache(bool) error + SetSQLSelectLimit(int64) error SetTransactionMode(vtgatepb.TransactionMode) SetWorkload(querypb.ExecuteOptions_Workload) SetFoundRows(uint64) @@ -127,10 +127,10 @@ type ( // each node does its part by combining the results of the // sub-nodes. Plan struct { - Type sqlparser.StatementType // The type of query we have - Original string // Original is the original query. - Instructions Primitive // Instructions contains the instructions needed to fulfil the query. - sqlparser.BindVarNeeds // Stores BindVars needed to be provided as part of expression rewriting + Type sqlparser.StatementType // The type of query we have + Original string // Original is the original query. + Instructions Primitive // Instructions contains the instructions needed to fulfil the query. + BindVarNeeds *sqlparser.BindVarNeeds // Stores BindVars needed to be provided as part of expression rewriting mu sync.Mutex // Mutex to protect the fields below ExecCount uint64 // Count of times this plan was executed diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 0e0f9487adc..d886e2a9336 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -38,7 +38,10 @@ func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindV } if wantfields { - p.addFields(result, bindVars) + err := p.addFields(result, bindVars) + if err != nil { + return nil, err + } } var rows [][]sqltypes.Value for _, row := range result.Rows { @@ -67,7 +70,10 @@ func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*querypb } if wantields { - p.addFields(result, bindVars) + err = p.addFields(result, bindVars) + if err != nil { + return err + } } var rows [][]sqltypes.Value for _, row := range result.Rows { @@ -90,18 +96,26 @@ func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bin if err != nil { return nil, err } - p.addFields(qr, bindVars) + err = p.addFields(qr, bindVars) + if err != nil { + return nil, err + } return qr, nil } -func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) { +func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) error { env := evalengine.ExpressionEnv{BindVars: bindVars} for i, col := range p.Cols { + q, err := p.Exprs[i].Type(env) + if err != nil { + return err + } qr.Fields = append(qr.Fields, &querypb.Field{ Name: col, - Type: p.Exprs[i].Type(env), + Type: q, }) } + return nil } func (p *Projection) Inputs() []Primitive { diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index 7c8148c75c1..a39d115c041 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -22,6 +22,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/sysvars" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/log" @@ -73,8 +75,8 @@ type ( Expr string } - // SysVarSet implements the SetOp interface and will write the changes variable into the session - SysVarSet struct { + // SysVarReservedConn implements the SetOp interface and will write the changes variable into the session + SysVarReservedConn struct { Name string Keyspace *vindexes.Keyspace TargetDestination key.Destination `json:",omitempty"` @@ -252,27 +254,27 @@ func (svci *SysVarCheckAndIgnore) Execute(vcursor VCursor, env evalengine.Expres return nil } -var _ SetOp = (*SysVarSet)(nil) +var _ SetOp = (*SysVarReservedConn)(nil) //MarshalJSON provides the type to SetOp for plan json -func (svs *SysVarSet) MarshalJSON() ([]byte, error) { +func (svs *SysVarReservedConn) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string - SysVarSet + SysVarReservedConn }{ - Type: "SysVarSet", - SysVarSet: *svs, + Type: "SysVarSet", + SysVarReservedConn: *svs, }) } //VariableName implements the SetOp interface method -func (svs *SysVarSet) VariableName() string { +func (svs *SysVarReservedConn) VariableName() string { return svs.Name } //Execute implements the SetOp interface method -func (svs *SysVarSet) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { +func (svs *SysVarReservedConn) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { // For those running on advanced vitess settings. if svs.TargetDestination != nil { rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{svs.TargetDestination}) @@ -306,7 +308,7 @@ func (svs *SysVarSet) Execute(vcursor VCursor, env evalengine.ExpressionEnv) err return vterrors.Aggregate(errs) } -func (svs *SysVarSet) execSetStatement(vcursor VCursor, rss []*srvtopo.ResolvedShard, env evalengine.ExpressionEnv) error { +func (svs *SysVarReservedConn) execSetStatement(vcursor VCursor, rss []*srvtopo.ResolvedShard, env evalengine.ExpressionEnv) error { queries := make([]*querypb.BoundQuery, len(rss)) for i := 0; i < len(rss); i++ { queries[i] = &querypb.BoundQuery{ @@ -318,7 +320,7 @@ func (svs *SysVarSet) execSetStatement(vcursor VCursor, rss []*srvtopo.ResolvedS return vterrors.Aggregate(errs) } -func (svs *SysVarSet) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) { +func (svs *SysVarReservedConn) checkAndUpdateSysVar(vcursor VCursor, res evalengine.ExpressionEnv) (bool, error) { sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where @@%s != %s", svs.Expr, svs.Name, svs.Expr) rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{key.DestinationKeyspaceID{0}}) if err != nil { @@ -342,21 +344,7 @@ func (svs *SysVarSet) checkAndUpdateSysVar(vcursor VCursor, res evalengine.Expre var _ SetOp = (*SysVarSetAware)(nil) -// System variables that needs special handling -const ( - Autocommit = "autocommit" - ClientFoundRows = "client_found_rows" - SkipQueryPlanCache = "skip_query_plan_cache" - TxReadOnly = "tx_read_only" - TransactionReadOnly = "transaction_read_only" - SQLSelectLimit = "sql_select_limit" - TransactionMode = "transaction_mode" - Workload = "workload" - Charset = "charset" - Names = "names" -) - -//MarshalJSON provides the type to SetOp for plan json +//MarshalJSON marshals all the json func (svss *SysVarSetAware) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string @@ -371,84 +359,105 @@ func (svss *SysVarSetAware) MarshalJSON() ([]byte, error) { //Execute implements the SetOp interface method func (svss *SysVarSetAware) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { + var err error switch svss.Name { - // These are all the boolean values we need to handle - case Autocommit, ClientFoundRows, SkipQueryPlanCache, TxReadOnly, TransactionReadOnly: - value, err := svss.Expr.Evaluate(env) - if err != nil { - return err - } - boolValue, err := value.ToBooleanStrict() + case sysvars.Autocommit.Name: + err = svss.setBoolSysVar(env, vcursor.Session().SetAutocommit) + case sysvars.ClientFoundRows.Name: + err = svss.setBoolSysVar(env, vcursor.Session().SetClientFoundRows) + case sysvars.SkipQueryPlanCache.Name: + err = svss.setBoolSysVar(env, vcursor.Session().SetSkipQueryPlanCache) + case sysvars.TxReadOnly.Name, + sysvars.TransactionReadOnly.Name: + // TODO (4127): This is a dangerous NOP. + noop := func(bool) error { return nil } + err = svss.setBoolSysVar(env, noop) + case sysvars.SQLSelectLimit.Name: + intValue, err := svss.evalAsInt64(env) if err != nil { - return vterrors.Wrapf(err, "System setting '%s' can't be set to this value", svss.Name) - } - switch svss.Name { - case Autocommit: - vcursor.Session().SetAutocommit(boolValue) - case ClientFoundRows: - vcursor.Session().SetClientFoundRows(boolValue) - case SkipQueryPlanCache: - vcursor.Session().SetSkipQueryPlanCache(boolValue) - case TxReadOnly, TransactionReadOnly: - // TODO (4127): This is a dangerous NOP. + return vterrors.Wrapf(err, "failed to evaluate value for %s", sysvars.SQLSelectLimit.Name) } - - case SQLSelectLimit: - value, err := svss.Expr.Evaluate(env) + vcursor.Session().SetSQLSelectLimit(intValue) + case sysvars.TransactionMode.Name: + str, err := svss.evalAsString(env) if err != nil { return err } - - v := value.Value() - if !v.IsIntegral() { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for sql_select_limit: %T", value.Value().Type().String()) + out, ok := vtgatepb.TransactionMode_value[strings.ToUpper(str)] + if !ok { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid transaction_mode: %s", str) } - intValue, err := v.ToInt64() + vcursor.Session().SetTransactionMode(vtgatepb.TransactionMode(out)) + case sysvars.Workload.Name: + str, err := svss.evalAsString(env) if err != nil { return err } - vcursor.Session().SetSQLSelectLimit(intValue) - - // String settings - case TransactionMode, Workload, Charset, Names: - value, err := svss.Expr.Evaluate(env) + out, ok := querypb.ExecuteOptions_Workload_value[strings.ToUpper(str)] + if !ok { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid workload: %s", str) + } + vcursor.Session().SetWorkload(querypb.ExecuteOptions_Workload(out)) + case sysvars.Charset.Name, sysvars.Names.Name: + str, err := svss.evalAsString(env) if err != nil { return err } - v := value.Value() - if !v.IsText() && !v.IsBinary() { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for %s: %s", svss.Name, value.Value().Type().String()) - } - - str := v.ToString() - switch svss.Name { - case TransactionMode: - out, ok := vtgatepb.TransactionMode_value[strings.ToUpper(str)] - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid transaction_mode: %s", str) - } - vcursor.Session().SetTransactionMode(vtgatepb.TransactionMode(out)) - case Workload: - out, ok := querypb.ExecuteOptions_Workload_value[strings.ToUpper(str)] - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid workload: %s", str) - } - vcursor.Session().SetWorkload(querypb.ExecuteOptions_Workload(out)) - case Charset, Names: - switch strings.ToLower(str) { - case "", "utf8", "utf8mb4", "latin1", "default": - // do nothing - break - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for charset/names: %v", str) - } + switch strings.ToLower(str) { + case "", "utf8", "utf8mb4", "latin1", "default": + // do nothing + break + default: + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for charset/names: %v", str) } default: - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported construct") + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported construct %s", svss.Name) } - return nil + return err +} + +func (svss *SysVarSetAware) evalAsInt64(env evalengine.ExpressionEnv) (int64, error) { + value, err := svss.Expr.Evaluate(env) + if err != nil { + return 0, err + } + + v := value.Value() + if !v.IsIntegral() { + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "expected int, unexpected value type: %T", value.Value().Type().String()) + } + intValue, err := v.ToInt64() + if err != nil { + return 0, err + } + return intValue, nil +} + +func (svss *SysVarSetAware) evalAsString(env evalengine.ExpressionEnv) (string, error) { + value, err := svss.Expr.Evaluate(env) + if err != nil { + return "", err + } + v := value.Value() + if !v.IsText() && !v.IsBinary() { + return "", vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for %s: %s", svss.Name, value.Value().Type().String()) + } + + return v.ToString(), nil +} + +func (svss *SysVarSetAware) setBoolSysVar(env evalengine.ExpressionEnv, setter func(bool) error) error { + value, err := svss.Expr.Evaluate(env) + if err != nil { + return err + } + boolValue, err := value.ToBooleanStrict() + if err != nil { + return vterrors.Wrapf(err, "System setting '%s' can't be set to this value", svss.Name) + } + return setter(boolValue) } //VariableName implements the SetOp interface method diff --git a/go/vt/vtgate/engine/set_test.go b/go/vt/vtgate/engine/set_test.go index fb23d7ddf0e..afd0b0bc6f9 100644 --- a/go/vt/vtgate/engine/set_test.go +++ b/go/vt/vtgate/engine/set_test.go @@ -31,7 +31,7 @@ import ( ) func TestSetSystemVariableAsString(t *testing.T) { - setOp := SysVarSet{ + setOp := SysVarReservedConn{ Name: "x", Keyspace: &vindexes.Keyspace{ Name: "ks", @@ -198,7 +198,7 @@ func TestSetTable(t *testing.T) { { testName: "sysvar set without destination", setOps: []SetOp{ - &SysVarSet{ + &SysVarReservedConn{ Name: "x", Keyspace: &vindexes.Keyspace{ Name: "ks", @@ -216,7 +216,7 @@ func TestSetTable(t *testing.T) { { testName: "sysvar set not modifying setting", setOps: []SetOp{ - &SysVarSet{ + &SysVarReservedConn{ Name: "x", Keyspace: &vindexes.Keyspace{ Name: "ks", @@ -233,7 +233,7 @@ func TestSetTable(t *testing.T) { { testName: "sysvar set modifying setting", setOps: []SetOp{ - &SysVarSet{ + &SysVarReservedConn{ Name: "x", Keyspace: &vindexes.Keyspace{ Name: "ks", @@ -282,7 +282,7 @@ func TestSetTable(t *testing.T) { func TestSysVarSetErr(t *testing.T) { setOps := []SetOp{ - &SysVarSet{ + &SysVarReservedConn{ Name: "x", Keyspace: &vindexes.Keyspace{ Name: "ks", diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index d561b17551f..153ecd09bae 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -45,7 +45,7 @@ type ( // Expr is the interface that all evaluating expressions must implement Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) - Type(env ExpressionEnv) querypb.Type + Type(env ExpressionEnv) (querypb.Type, error) String() string } @@ -202,27 +202,37 @@ func (s *Subtraction) Type(left querypb.Type) querypb.Type { } //Type implements the Expr interface -func (b *BinaryOp) Type(env ExpressionEnv) querypb.Type { - ltype := b.Left.Type(env) - rtype := b.Right.Type(env) +func (b *BinaryOp) Type(env ExpressionEnv) (querypb.Type, error) { + ltype, err := b.Left.Type(env) + if err != nil { + return 0, err + } + rtype, err := b.Right.Type(env) + if err != nil { + return 0, err + } typ := mergeNumericalTypes(ltype, rtype) - return b.Expr.Type(typ) + return b.Expr.Type(typ), nil } //Type implements the Expr interface -func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { +func (b *BindVariable) Type(env ExpressionEnv) (querypb.Type, error) { e := env.BindVars - return e[b.Key].Type + v, found := e[b.Key] + if !found { + return querypb.Type_NULL_TYPE, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query arguments missing for %s", b.Key) + } + return v.Type, nil } //Type implements the Expr interface -func (l *Literal) Type(ExpressionEnv) querypb.Type { - return l.Val.typ +func (l *Literal) Type(ExpressionEnv) (querypb.Type, error) { + return l.Val.typ, nil } //Type implements the Expr interface -func (c *Column) Type(ExpressionEnv) querypb.Type { - return sqltypes.Float64 +func (c *Column) Type(ExpressionEnv) (querypb.Type, error) { + return sqltypes.Float64, nil } //String implements the BinaryExpr interface @@ -287,6 +297,12 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { ival = 0 } return EvalResult{typ: sqltypes.Int64, ival: ival}, nil + case sqltypes.Int32: + ival, err := strconv.ParseInt(string(val.Value), 10, 32) + if err != nil { + ival = 0 + } + return EvalResult{typ: sqltypes.Int32, ival: ival}, nil case sqltypes.Uint64: uval, err := strconv.ParseUint(string(val.Value), 10, 64) if err != nil { diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 6f8649f8209..62035ea41b0 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -29,6 +29,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/sysvars" + "golang.org/x/net/context" "vitess.io/vitess/go/trace" "vitess.io/vitess/go/vt/discovery" @@ -71,10 +73,11 @@ var ( ) const ( - utf8 = "utf8" - utf8mb4 = "utf8mb4" - both = "both" - charset = "charset" + utf8 = "utf8" + utf8mb4 = "utf8mb4" + both = "both" + charset = "charset" + bindVarPrefix = "__vt" ) func init() { @@ -241,13 +244,51 @@ func (e *Executor) legacyExecute(ctx context.Context, safeSession *SafeSession, } // addNeededBindVars adds bind vars that are needed by the plan -func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error { - if bindVarNeeds.NeedDatabase { - bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(session.TargetString) +func (e *Executor) addNeededBindVars(bindVarNeeds *sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error { + for _, funcName := range bindVarNeeds.NeedFunctionResult { + switch funcName { + case sqlparser.DBVarName: + bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(session.TargetString) + case sqlparser.LastInsertIDName: + bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(session.GetLastInsertId()) + case sqlparser.FoundRowsName: + bindVars[sqlparser.FoundRowsName] = sqltypes.Uint64BindVariable(session.FoundRows) + case sqlparser.RowCountName: + bindVars[sqlparser.RowCountName] = sqltypes.Int64BindVariable(session.RowCount) + } } - if bindVarNeeds.NeedLastInsertID { - bindVars[sqlparser.LastInsertIDName] = sqltypes.Uint64BindVariable(session.GetLastInsertId()) + for _, funcName := range bindVarNeeds.NeedSystemVariable { + switch funcName { + case sysvars.Autocommit.Name: + bindVars[bindVarPrefix+sysvars.Autocommit.Name] = sqltypes.BoolBindVariable(session.Autocommit) + case sysvars.ClientFoundRows.Name: + var v bool + ifOptionsExist(session, func(options *querypb.ExecuteOptions) { + v = options.ClientFoundRows + }) + bindVars[bindVarPrefix+sysvars.ClientFoundRows.Name] = sqltypes.BoolBindVariable(v) + case sysvars.SkipQueryPlanCache.Name: + var v bool + ifOptionsExist(session, func(options *querypb.ExecuteOptions) { + v = options.ClientFoundRows + }) + bindVars[bindVarPrefix+sysvars.SkipQueryPlanCache.Name] = sqltypes.BoolBindVariable(v) + case sysvars.SQLSelectLimit.Name: + var v int64 + ifOptionsExist(session, func(options *querypb.ExecuteOptions) { + v = options.SqlSelectLimit + }) + bindVars[bindVarPrefix+sysvars.SQLSelectLimit.Name] = sqltypes.Int64BindVariable(v) + case sysvars.TransactionMode.Name: + bindVars[bindVarPrefix+sysvars.TransactionMode.Name] = sqltypes.StringBindVariable(session.TransactionMode.String()) + case sysvars.Workload.Name: + var v string + ifOptionsExist(session, func(options *querypb.ExecuteOptions) { + v = options.GetWorkload().String() + }) + bindVars[bindVarPrefix+sysvars.Workload.Name] = sqltypes.StringBindVariable(v) + } } udvMap := session.UserDefinedVariables @@ -262,15 +303,14 @@ func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVa bindVars[sqlparser.UserDefinedVariableName+udv] = val } - if bindVarNeeds.NeedFoundRows { - bindVars[sqlparser.FoundRowsName] = sqltypes.Uint64BindVariable(session.FoundRows) - } + return nil +} - if bindVarNeeds.NeedRowCount { - bindVars[sqlparser.RowCountName] = sqltypes.Int64BindVariable(session.RowCount) +func ifOptionsExist(session *SafeSession, f func(*querypb.ExecuteOptions)) { + options := session.GetOptions() + if options != nil { + f(options) } - - return nil } func (e *Executor) destinationExec(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, dest key.Destination, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats, ignoreMaxMemoryRows bool) (*sqltypes.Result, error) { @@ -1147,7 +1187,7 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. } query := sql statement := stmt - bindVarNeeds := sqlparser.BindVarNeeds{} + bindVarNeeds := &sqlparser.BindVarNeeds{} if !sqlparser.IgnoreMaxPayloadSizeDirective(statement) && !isValidPayloadSize(query) { return nil, mysql.NewSQLError(mysql.ERNetPacketTooLarge, "", "query payload size above threshold") } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index b3c9bcbb643..88a2f808c68 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -315,6 +315,39 @@ func TestSelectLastInsertId(t *testing.T) { utils.MustMatch(t, result, wantResult, "Mismatch") } +func TestSelectSystemVariables(t *testing.T) { + masterSession.LastInsertId = 52 + executor, _, _, _ := createLegacyExecutorEnv() + executor.normalize = true + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + sql := "select @@autocommit, @@client_found_rows, @@skip_query_plan_cache, @@sql_select_limit, @@transaction_mode, @@workload" + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@@autocommit", Type: sqltypes.Int32}, + {Name: "@@client_found_rows", Type: sqltypes.Int32}, + {Name: "@@skip_query_plan_cache", Type: sqltypes.Int32}, + {Name: "@@sql_select_limit", Type: sqltypes.Int64}, + {Name: "@@transaction_mode", Type: sqltypes.VarBinary}, + {Name: "@@workload", Type: sqltypes.VarBinary}, + }, + RowsAffected: 1, + Rows: [][]sqltypes.Value{{ + // the following are the uninitialised session values + sqltypes.NULL, + sqltypes.NULL, + sqltypes.NULL, + sqltypes.NewInt64(0), + sqltypes.NewVarBinary("UNSPECIFIED"), + sqltypes.NewVarBinary(""), + }}, + } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") +} + func TestSelectUserDefindVariable(t *testing.T) { executor, _, _, _ := createLegacyExecutorEnv() executor.normalize = true diff --git a/go/vt/vtgate/executor_set_test.go b/go/vt/vtgate/executor_set_test.go index bde58c4af98..c9fb9f1859e 100644 --- a/go/vt/vtgate/executor_set_test.go +++ b/go/vt/vtgate/executor_set_test.go @@ -170,7 +170,7 @@ func TestExecutorSet(t *testing.T) { out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{SqlSelectLimit: 0}}, }, { in: "set sql_select_limit = 'asdfasfd'", - err: "unexpected value type for sql_select_limit: string", + err: "failed to evaluate value for sql_select_limit: expected int, unexpected value type: string", }, { in: "set autocommit = 1+1", err: "System setting 'autocommit' can't be set to this value: 2 is not a boolean", @@ -355,7 +355,7 @@ func TestExecutorSetMetadata(t *testing.T) { assert.NoError(t, err) want := "1" - got := string(result.Rows[0][1].ToString()) + got := result.Rows[0][1].ToString() assert.Equalf(t, want, got, "want migrations %s, result %s", want, got) // Update metadata diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index c1f12e64dd6..ac240b0d265 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -281,7 +281,7 @@ func Build(query string, vschema ContextVSchema) (*engine.Plan, error) { var ErrPlanNotSupported = errors.New("plan building not supported") // BuildFromStmt builds a plan based on the AST provided. -func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema, bindVarNeeds sqlparser.BindVarNeeds) (*engine.Plan, error) { +func BuildFromStmt(query string, stmt sqlparser.Statement, vschema ContextVSchema, bindVarNeeds *sqlparser.BindVarNeeds) (*engine.Plan, error) { instruction, err := createInstructionFor(query, stmt, vschema) if err != nil { return nil, err diff --git a/go/vt/vtgate/planbuilder/expression_converter.go b/go/vt/vtgate/planbuilder/expression_converter.go index 7a19c9e5dd8..751228e9faf 100644 --- a/go/vt/vtgate/planbuilder/expression_converter.go +++ b/go/vt/vtgate/planbuilder/expression_converter.go @@ -32,6 +32,10 @@ type expressionConverter struct { } func booleanValues(astExpr sqlparser.Expr) evalengine.Expr { + var ( + ON = evalengine.NewLiteralInt(1) + OFF = evalengine.NewLiteralInt(0) + ) switch node := astExpr.(type) { case *sqlparser.Literal: //set autocommit = 'on' diff --git a/go/vt/vtgate/planbuilder/set.go b/go/vt/vtgate/planbuilder/set.go index cf824afba67..b89b9a95bfc 100644 --- a/go/vt/vtgate/planbuilder/set.go +++ b/go/vt/vtgate/planbuilder/set.go @@ -108,9 +108,13 @@ func buildNotSupported(setting) planFunc { func buildSetOpIgnore(s setting) planFunc { return func(expr *sqlparser.SetExpr, vschema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { + value, err := extractValue(expr, s.boolean) + if err != nil { + return nil, err + } return &engine.SysVarIgnore{ Name: expr.Name.Lowered(), - Expr: extractValue(expr, s.boolean), + Expr: value, }, nil } } @@ -126,12 +130,16 @@ func planSysVarCheckIgnore(expr *sqlparser.SetExpr, schema ContextVSchema, boole if err != nil { return nil, err } + value, err := extractValue(expr, boolean) + if err != nil { + return nil, err + } return &engine.SysVarCheckAndIgnore{ Name: expr.Name.Lowered(), Keyspace: keyspace, TargetDestination: dest, - Expr: extractValue(expr, boolean), + Expr: value, }, nil } @@ -155,7 +163,7 @@ func expressionOkToDelegateToTablet(e sqlparser.Expr) bool { return valid } -func buildSetOpVarSet(s setting) planFunc { +func buildSetOpReservedConn(s setting) planFunc { return func(expr *sqlparser.SetExpr, vschema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { if !vschema.SysVarSetEnabled() { return planSysVarCheckIgnore(expr, vschema, s.boolean) @@ -164,16 +172,22 @@ func buildSetOpVarSet(s setting) planFunc { if err != nil { return nil, err } + value, err := extractValue(expr, s.boolean) + if err != nil { + return nil, err + } - return &engine.SysVarSet{ + return &engine.SysVarReservedConn{ Name: expr.Name.Lowered(), Keyspace: ks, TargetDestination: vschema.Destination(), - Expr: extractValue(expr, s.boolean), + Expr: value, }, nil } } +const defaultNotSupportedErrFmt = "DEFAULT not supported for @@%s" + func buildSetOpVitessAware(s setting) planFunc { return func(astExpr *sqlparser.SetExpr, vschema ContextVSchema, ec *expressionConverter) (engine.SetOp, error) { var err error @@ -182,7 +196,8 @@ func buildSetOpVitessAware(s setting) planFunc { _, isDefault := astExpr.Expr.(*sqlparser.Default) if isDefault { if s.defaultValue == nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "don't know default value for %s", astExpr.Name) + + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, defaultNotSupportedErrFmt, astExpr.Name) } runtimeExpr = s.defaultValue } else { @@ -212,15 +227,15 @@ func resolveDestination(vschema ContextVSchema) (*vindexes.Keyspace, key.Destina return keyspace, dest, nil } -func extractValue(expr *sqlparser.SetExpr, boolean bool) string { +func extractValue(expr *sqlparser.SetExpr, boolean bool) (string, error) { switch node := expr.Expr.(type) { case *sqlparser.Literal: if node.Type == sqlparser.StrVal && boolean { switch strings.ToLower(string(node.Val)) { case "on": - return "1" + return "1", nil case "off": - return "0" + return "0", nil } } case *sqlparser.ColName: @@ -229,15 +244,17 @@ func extractValue(expr *sqlparser.SetExpr, boolean bool) string { if node.Name.AtCount() == sqlparser.NoAt { switch node.Name.Lowered() { case "on": - return "1" + return "1", nil case "off": - return "0" + return "0", nil } - return fmt.Sprintf("'%s'", sqlparser.String(expr.Expr)) + return fmt.Sprintf("'%s'", sqlparser.String(expr.Expr)), nil } + case *sqlparser.Default: + return "", vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, defaultNotSupportedErrFmt, expr.Name) } - return sqlparser.String(expr.Expr) + return sqlparser.String(expr.Expr), nil } // whitelist of functions knows to be safe to pass through to mysql for evaluation diff --git a/go/vt/vtgate/planbuilder/system_settings.go b/go/vt/vtgate/planbuilder/system_settings.go deleted file mode 100644 index 9cc04698f22..00000000000 --- a/go/vt/vtgate/planbuilder/system_settings.go +++ /dev/null @@ -1,213 +0,0 @@ -/* -Copyright 2020 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package planbuilder - -import ( - "vitess.io/vitess/go/vt/vtgate/engine" - "vitess.io/vitess/go/vt/vtgate/evalengine" -) - -func init() { - forSettings(ignoreThese, buildSetOpIgnore) - forSettings(useReservedConn, buildSetOpVarSet) - forSettings(checkAndIgnore, buildSetOpCheckAndIgnore) - forSettings(notSupported, buildNotSupported) - forSettings(vitessAware, buildSetOpVitessAware) -} - -func forSettings(settings []setting, f func(s setting) planFunc) { - for _, setting := range settings { - if _, alreadyExists := sysVarPlanningFunc[setting.name]; alreadyExists { - panic("bug in set plan init - " + setting.name + " aleady configured") - } - sysVarPlanningFunc[setting.name] = f(setting) - } -} - -var ( - ON = evalengine.NewLiteralInt(1) - OFF = evalengine.NewLiteralInt(0) - - vitessAware = []setting{ - {name: engine.Autocommit, boolean: true, defaultValue: ON}, - {name: engine.ClientFoundRows, boolean: true, defaultValue: OFF}, - {name: engine.SkipQueryPlanCache, boolean: true, defaultValue: OFF}, - {name: engine.TransactionReadOnly, boolean: true, defaultValue: OFF}, - {name: engine.TxReadOnly, boolean: true, defaultValue: OFF}, - {name: engine.SQLSelectLimit, defaultValue: OFF}, - {name: engine.TransactionMode, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("MULTI"))}, - {name: engine.Workload, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("UNSPECIFIED"))}, - {name: engine.Charset, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("utf8"))}, - {name: engine.Names, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("utf8"))}, - } - - notSupported = []setting{ - {name: "audit_log_read_buffer_size"}, - {name: "auto_increment_increment"}, - {name: "auto_increment_offset"}, - {name: "binlog_direct_non_transactional_updates"}, - {name: "binlog_row_image"}, - {name: "binlog_rows_query_log_events"}, - {name: "innodb_ft_enable_stopword"}, - {name: "innodb_ft_user_stopword_table"}, - {name: "max_points_in_geometry"}, - {name: "max_sp_recursion_depth"}, - {name: "myisam_repair_threads"}, - {name: "myisam_sort_buffer_size"}, - {name: "myisam_stats_method"}, - {name: "ndb_allow_copying_alter_table"}, - {name: "ndb_autoincrement_prefetch_sz"}, - {name: "ndb_blob_read_batch_bytes"}, - {name: "ndb_blob_write_batch_bytes"}, - {name: "ndb_deferred_constraints"}, - {name: "ndb_force_send"}, - {name: "ndb_fully_replicated"}, - {name: "ndb_index_stat_enable"}, - {name: "ndb_index_stat_option"}, - {name: "ndb_join_pushdown"}, - {name: "ndb_log_bin"}, - {name: "ndb_log_exclusive_reads"}, - {name: "ndb_row_checksum"}, - {name: "ndb_use_exact_count"}, - {name: "ndb_use_transactions"}, - {name: "ndbinfo_max_bytes"}, - {name: "ndbinfo_max_rows"}, - {name: "ndbinfo_show_hidden"}, - {name: "ndbinfo_table_prefix"}, - {name: "old_alter_table"}, - {name: "preload_buffer_size"}, - {name: "rbr_exec_mode"}, - {name: "sql_log_off"}, - {name: "thread_pool_high_priority_connection"}, - {name: "thread_pool_prio_kickup_timer"}, - {name: "transaction_write_set_extraction"}, - } - - ignoreThese = []setting{ - {name: "big_tables", boolean: true}, - {name: "bulk_insert_buffer_size"}, - {name: "debug"}, - {name: "default_storage_engine"}, - {name: "default_tmp_storage_engine"}, - {name: "innodb_strict_mode", boolean: true}, - {name: "innodb_support_xa", boolean: true}, - {name: "innodb_table_locks", boolean: true}, - {name: "innodb_tmpdir"}, - {name: "join_buffer_size"}, - {name: "keep_files_on_create", boolean: true}, - {name: "lc_messages"}, - {name: "long_query_time"}, - {name: "low_priority_updates", boolean: true}, - {name: "max_delayed_threads"}, - {name: "max_insert_delayed_threads"}, - {name: "multi_range_count"}, - {name: "net_buffer_length"}, - {name: "new", boolean: true}, - {name: "query_cache_type"}, - {name: "query_cache_wlock_invalidate", boolean: true}, - {name: "query_prealloc_size"}, - {name: "sql_buffer_result", boolean: true}, - {name: "transaction_alloc_block_size"}, - {name: "wait_timeout"}, - } - - useReservedConn = []setting{ - {name: "default_week_format"}, - {name: "end_markers_in_json", boolean: true}, - {name: "eq_range_index_dive_limit"}, - {name: "explicit_defaults_for_timestamp"}, - {name: "foreign_key_checks", boolean: true}, - {name: "group_concat_max_len"}, - {name: "max_heap_table_size"}, - {name: "max_seeks_for_key"}, - {name: "max_tmp_tables"}, - {name: "min_examined_row_limit"}, - {name: "old_passwords"}, - {name: "optimizer_prune_level"}, - {name: "optimizer_search_depth"}, - {name: "optimizer_switch"}, - {name: "optimizer_trace"}, - {name: "optimizer_trace_features"}, - {name: "optimizer_trace_limit"}, - {name: "optimizer_trace_max_mem_size"}, - {name: "transaction_isolation"}, - {name: "tx_isolation"}, - {name: "optimizer_trace_offset"}, - {name: "parser_max_mem_size"}, - {name: "profiling", boolean: true}, - {name: "profiling_history_size"}, - {name: "query_alloc_block_size"}, - {name: "range_alloc_block_size"}, - {name: "range_optimizer_max_mem_size"}, - {name: "read_buffer_size"}, - {name: "read_rnd_buffer_size"}, - {name: "show_create_table_verbosity", boolean: true}, - {name: "show_old_temporals", boolean: true}, - {name: "sort_buffer_size"}, - {name: "sql_big_selects", boolean: true}, - {name: "sql_mode"}, - {name: "sql_notes", boolean: true}, - {name: "sql_quote_show_create", boolean: true}, - {name: "sql_safe_updates", boolean: true}, - {name: "sql_warnings", boolean: true}, - {name: "tmp_table_size"}, - {name: "transaction_prealloc_size"}, - {name: "unique_checks", boolean: true}, - {name: "updatable_views_with_limit", boolean: true}, - } - - // TODO: Most of these settings should be moved into SysSetOpAware, and change Vitess behaviour. - // Until then, SET statements against these settings are allowed - // as long as they have the same value as the underlying database - checkAndIgnore = []setting{ - {name: "binlog_format"}, - {name: "block_encryption_mode"}, - {name: "character_set_client"}, - {name: "character_set_connection"}, - {name: "character_set_database"}, - {name: "character_set_filesystem"}, - {name: "character_set_results"}, - {name: "character_set_server"}, - {name: "collation_connection"}, - {name: "collation_database"}, - {name: "collation_server"}, - {name: "completion_type"}, - {name: "div_precision_increment"}, - {name: "innodb_lock_wait_timeout"}, - {name: "interactive_timeout"}, - {name: "lc_time_names"}, - {name: "lock_wait_timeout"}, - {name: "max_allowed_packet"}, - {name: "max_error_count"}, - {name: "max_execution_time"}, - {name: "max_join_size"}, - {name: "max_length_for_sort_data"}, - {name: "max_sort_length"}, - {name: "max_user_connections"}, - {name: "net_read_timeout"}, - {name: "net_retry_count"}, - {name: "net_write_timeout"}, - {name: "session_track_gtids"}, - {name: "session_track_schema", boolean: true}, - {name: "session_track_state_change", boolean: true}, - {name: "session_track_system_variables"}, - {name: "session_track_transaction_info"}, - {name: "sql_auto_is_null", boolean: true}, - {name: "time_zone"}, - {name: "version_tokens_session"}, - } -) diff --git a/go/vt/vtgate/planbuilder/system_variables.go b/go/vt/vtgate/planbuilder/system_variables.go new file mode 100644 index 00000000000..ce02cb497ff --- /dev/null +++ b/go/vt/vtgate/planbuilder/system_variables.go @@ -0,0 +1,68 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "fmt" + + "vitess.io/vitess/go/vt/sysvars" + + "vitess.io/vitess/go/vt/sqlparser" + + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +func init() { + forSettings(sysvars.IgnoreThese, buildSetOpIgnore) + forSettings(sysvars.UseReservedConn, buildSetOpReservedConn) + forSettings(sysvars.CheckAndIgnore, buildSetOpCheckAndIgnore) + forSettings(sysvars.NotSupported, buildNotSupported) + forSettings(sysvars.VitessAware, buildSetOpVitessAware) +} + +func forSettings(systemVariables []sysvars.SystemVariable, f func(setting) planFunc) { + for _, sysvar := range systemVariables { + if _, alreadyExists := sysVarPlanningFunc[sysvar.Name]; alreadyExists { + panic("bug in set plan init - " + sysvar.Name + " already configured") + } + + s := setting{ + name: sysvar.Name, + boolean: sysvar.IsBoolean, + identifierAsString: sysvar.IdentifierAsString, + } + + if sysvar.Default != "" { + s.defaultValue = parseAndBuildDefaultValue(sysvar) + } + sysVarPlanningFunc[sysvar.Name] = f(s) + } +} + +func parseAndBuildDefaultValue(sysvar sysvars.SystemVariable) evalengine.Expr { + stmt, err := sqlparser.Parse(fmt.Sprintf("select %s", sysvar.Default)) + if err != nil { + panic(fmt.Sprintf("bug in set plan init - default value for %s not parsable: %s", sysvar.Name, sysvar.Default)) + } + sel := stmt.(*sqlparser.Select) + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + def, err := sqlparser.Convert(aliasedExpr.Expr) + if err != nil { + panic(fmt.Sprintf("bug in set plan init - default value for %s not able to convert to evalengine.Expr: %s", sysvar.Name, sysvar.Default)) + } + return def +} diff --git a/go/vt/vtgate/planbuilder/testdata/set_cases.txt b/go/vt/vtgate/planbuilder/testdata/set_cases.txt index 6f99d615d57..c07ea13103e 100644 --- a/go/vt/vtgate/planbuilder/testdata/set_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/set_cases.txt @@ -457,4 +457,25 @@ ] } } - \ No newline at end of file + +# set autocommit to default +"set @@autocommit = default" +{ + "QueryType": "SET", + "Original": "set @@autocommit = default", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index f2b9a155564..3f851250904 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -427,3 +427,11 @@ # union with SQL_CALC_FOUND_ROWS "(select sql_calc_found_rows id from user where id = 1 limit 1) union select id from user where id = 1" "SQL_CALC_FOUND_ROWS not supported with union" + +# set with DEFAULT - vitess aware +"set workload = default" +"DEFAULT not supported for @@workload" + +# set with DEFAULT - reserved connection +"set sql_mode = default" +"DEFAULT not supported for @@sql_mode" diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 0acd4171887..56f8d132900 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -494,7 +494,7 @@ func (vc *vcursorImpl) TargetDestination(qualifier string) (key.Destination, *vi return vc.destination, keyspace.Keyspace, vc.tabletType, nil } -//SetAutocommit implementes the SessionActions interface +//SetAutocommit implements the SessionActions interface func (vc *vcursorImpl) SetAutocommit(autocommit bool) error { if autocommit && vc.safeSession.InTransaction() { if err := vc.executor.Commit(vc.ctx, vc.safeSession); err != nil { @@ -505,27 +505,30 @@ func (vc *vcursorImpl) SetAutocommit(autocommit bool) error { return nil } -//SetClientFoundRows implementes the SessionActions interface -func (vc *vcursorImpl) SetClientFoundRows(clientFoundRows bool) { +//SetClientFoundRows implements the SessionActions interface +func (vc *vcursorImpl) SetClientFoundRows(clientFoundRows bool) error { vc.safeSession.GetOrCreateOptions().ClientFoundRows = clientFoundRows + return nil } -//SetSkipQueryPlanCache implementes the SessionActions interface -func (vc *vcursorImpl) SetSkipQueryPlanCache(skipQueryPlanCache bool) { +//SetSkipQueryPlanCache implements the SessionActions interface +func (vc *vcursorImpl) SetSkipQueryPlanCache(skipQueryPlanCache bool) error { vc.safeSession.GetOrCreateOptions().SkipQueryPlanCache = skipQueryPlanCache + return nil } -//SetSkipQueryPlanCache implementes the SessionActions interface -func (vc *vcursorImpl) SetSQLSelectLimit(limit int64) { +//SetSkipQueryPlanCache implements the SessionActions interface +func (vc *vcursorImpl) SetSQLSelectLimit(limit int64) error { vc.safeSession.GetOrCreateOptions().SqlSelectLimit = limit + return nil } -//SetSkipQueryPlanCache implementes the SessionActions interface +//SetSkipQueryPlanCache implements the SessionActions interface func (vc *vcursorImpl) SetTransactionMode(mode vtgatepb.TransactionMode) { vc.safeSession.TransactionMode = mode } -//SetWorkload implementes the SessionActions interface +//SetWorkload implements the SessionActions interface func (vc *vcursorImpl) SetWorkload(workload querypb.ExecuteOptions_Workload) { vc.safeSession.GetOrCreateOptions().Workload = workload }