Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 86 additions & 72 deletions go/vt/proto/vtgate/vtgate.pb.go

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
var testMaxMemoryRows = 100

var _ VCursor = (*noopVCursor)(nil)
var _ SessionActions = (*noopVCursor)(nil)

// noopVCursor is used to build other vcursors.
type noopVCursor struct {
Expand All @@ -51,6 +52,10 @@ func (t noopVCursor) SetUDV(key string, value interface{}) error {
panic("implement me")
}

func (t noopVCursor) SetSysVar(name string, expr string) {
//panic("implement me")
}

func (t noopVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error {
panic("implement me")
}
Expand Down Expand Up @@ -105,6 +110,7 @@ func (t noopVCursor) ResolveDestinations(keyspace string, ids []*querypb.Value,
}

var _ VCursor = (*loggingVCursor)(nil)
var _ SessionActions = (*loggingVCursor)(nil)

// loggingVCursor logs requests and allows you to verify
// that the correct requests were made.
Expand Down Expand Up @@ -134,6 +140,10 @@ func (f *loggingVCursor) SetUDV(key string, value interface{}) error {
return nil
}

func (f *loggingVCursor) SetSysVar(name string, expr string) {
f.log = append(f.log, fmt.Sprintf("SysVar set with (%s,%v)", name, expr))
}

func (f *loggingVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error {
panic("implement me")
}
Expand Down
3 changes: 1 addition & 2 deletions go/vt/vtgate/engine/ordered_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@ func TestOrderedAggregateKeysFail(t *testing.T) {
}

func TestOrderedAggregateMergeFail(t *testing.T) {
t.Skip("this looks like an invalid test")
fields := sqltypes.MakeTestFields(
"col|count(*)",
"varbinary|decimal",
Expand All @@ -582,7 +581,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) {
results: []*sqltypes.Result{sqltypes.MakeTestResult(
fields,
"a|1",
"a|b",
"a|0",
)},
}

Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ type (
RecordWarning(warning *querypb.QueryWarning)

SetTarget(target string) error

SetUDV(key string, value interface{}) error

SetSysVar(name string, expr string)
}

// Plan represents the execution strategy for a given query.
Expand Down
46 changes: 46 additions & 0 deletions go/vt/vtgate/engine/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ type (
TargetDestination key.Destination
Expr string
}

// SysVarSet implements the SetOp interface and will write the changes variable into the session
SysVarSet struct {
Name string
Keyspace *vindexes.Keyspace
TargetDestination key.Destination
Expr string
}
)

var _ Primitive = (*Set)(nil)
Expand Down Expand Up @@ -235,3 +243,41 @@ func (svci *SysVarCheckAndIgnore) Execute(vcursor VCursor, env evalengine.Expres
vcursor.Session().RecordWarning(warning)
return nil
}

var _ SetOp = (*SysVarSet)(nil)

//MarshalJSON provides the type to SetOp for plan json
func (svs *SysVarSet) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
SysVarSet
}{
Type: "SysVarSet",
SysVarSet: *svs,
})

}

//VariableName implements the SetOp interface method
func (svs *SysVarSet) VariableName() string {
return svs.Name
}

//Execute implements the SetOp interface method
func (svs *SysVarSet) Execute(vcursor VCursor, res evalengine.ExpressionEnv) error {
rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{svs.TargetDestination})
if err != nil {
return vterrors.Wrap(err, "SysVarSet")
}

if len(rss) != 1 {
return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unexpected error, DestinationKeyspaceID mapping to multiple shards: %v", svs.TargetDestination)
}
sysVarExprValidationQuery := fmt.Sprintf("select %s from dual where false", svs.Expr)
_, err = execShard(vcursor, sysVarExprValidationQuery, res.BindVars, rss[0], false /* rollbackOnError */, false /* canAutocommit */)
if err != nil {
return err
}
vcursor.Session().SetSysVar(svs.Name, svs.Expr)
return nil
}
52 changes: 52 additions & 0 deletions go/vt/vtgate/engine/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"fmt"
"strconv"
"testing"

Expand Down Expand Up @@ -180,6 +181,25 @@ func TestSetTable(t *testing.T) {
"1",
)},
},
{
testName: "sysvar set",
setOps: []SetOp{
&SysVarSet{
Name: "x",
Keyspace: &vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
TargetDestination: key.DestinationAnyShard{},
Expr: "dummy_expr",
},
},
expectedQueryLog: []string{
`ResolveDestinations ks [] Destinations:DestinationAnyShard()`,
`ExecuteMultiShard ks.-20: select dummy_expr from dual where false {} false false`,
`SysVar set with (x,dummy_expr)`,
},
},
}

for _, tc := range tests {
Expand All @@ -204,3 +224,35 @@ func TestSetTable(t *testing.T) {
})
}
}

func TestSysVarSetErr(t *testing.T) {

setOps := []SetOp{
&SysVarSet{
Name: "x",
Keyspace: &vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
TargetDestination: key.DestinationAnyShard{},
Expr: "dummy_expr",
},
}

expectedQueryLog := []string{
`ResolveDestinations ks [] Destinations:DestinationAnyShard()`,
`ExecuteMultiShard ks.-20: select dummy_expr from dual where false {} false false`,
}

set := &Set{
Ops: setOps,
Input: &SingleRow{},
}
vc := &loggingVCursor{
shards: []string{"-20", "20-"},
multiShardErrs: []error{fmt.Errorf("error")},
}
_, err := set.Execute(vc, map[string]*querypb.BindVariable{}, false)
require.EqualError(t, err, "error")
vc.ExpectLog(t, expectedQueryLog)
}
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type ContextVSchema interface {
Destination() key.Destination
TabletType() topodatapb.TabletType
TargetDestination(qualifier string) (key.Destination, *vindexes.Keyspace, topodatapb.TabletType, error)
AnyKeyspace() (*vindexes.Keyspace, error)
}

//-------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ func (vw *vschemaWrapper) DefaultKeyspace() (*vindexes.Keyspace, error) {
return vw.v.Keyspaces["main"].Keyspace, nil
}

func (vw *vschemaWrapper) AnyKeyspace() (*vindexes.Keyspace, error) {
return vw.DefaultKeyspace()
}

func (vw *vschemaWrapper) TargetString() string {
return "targetString"
}
Expand Down
84 changes: 50 additions & 34 deletions go/vt/vtgate/planbuilder/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"fmt"
"strings"

"vitess.io/vitess/go/vt/vtgate/vindexes"

vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"

Expand All @@ -34,6 +36,7 @@ var sysVarPlanningFunc = map[string]func(expr *sqlparser.SetExpr, vschema Contex
func init() {
sysVarPlanningFunc["default_storage_engine"] = buildSetOpIgnore
sysVarPlanningFunc["sql_mode"] = buildSetOpCheckAndIgnore
sysVarPlanningFunc["sql_safe_updates"] = buildSetOpVarSet
}

func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive, error) {
Expand Down Expand Up @@ -100,24 +103,19 @@ func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive
}

func planTabletInput(vschema ContextVSchema, tabletExpressions []*sqlparser.SetExpr) (engine.Primitive, error) {
keyspace, err := vschema.DefaultKeyspace()
ks, dest, err := resolveDestination(vschema)
if err != nil {
return nil, err
}

dest := vschema.Destination()
if dest == nil {
dest = key.DestinationAnyShard{}
}

var expr []string
for _, e := range tabletExpressions {
expr = append(expr, sqlparser.String(e.Expr))
}
query := fmt.Sprintf("select %s from dual", strings.Join(expr, ","))

primitive := &engine.Send{
Keyspace: keyspace,
Keyspace: ks,
TargetDestination: dest,
Query: query,
IsDML: false,
Expand All @@ -137,20 +135,11 @@ func buildSetOpIgnore(expr *sqlparser.SetExpr, _ ContextVSchema) (engine.SetOp,
}

func buildSetOpCheckAndIgnore(expr *sqlparser.SetExpr, vschema ContextVSchema) (engine.SetOp, error) {
keyspace, err := vschema.DefaultKeyspace()
keyspace, dest, err := resolveDestination(vschema)
if err != nil {
//TODO: Record warning for switching plan construct.
if strings.HasPrefix(err.Error(), "no keyspace in database name specified") {
return buildSetOpIgnore(expr, vschema)
}
return nil, err
}

dest := vschema.Destination()
if dest == nil {
dest = key.DestinationAnyShard{}
}

return &engine.SysVarCheckAndIgnore{
Name: expr.Name.Lowered(),
Keyspace: keyspace,
Expand All @@ -159,6 +148,50 @@ func buildSetOpCheckAndIgnore(expr *sqlparser.SetExpr, vschema ContextVSchema) (
}, nil
}

func expressionOkToDelegateToTablet(e sqlparser.Expr) bool {
valid := true
sqlparser.Rewrite(e, nil, func(cursor *sqlparser.Cursor) bool {
switch n := cursor.Node().(type) {
case *sqlparser.Subquery, *sqlparser.TimestampFuncExpr, *sqlparser.CurTimeFuncExpr:
valid = false
return false
case *sqlparser.FuncExpr:
_, ok := validFuncs[n.Name.Lowered()]
valid = ok
return ok
}
return true
})
return valid
}

func buildSetOpVarSet(expr *sqlparser.SetExpr, vschema ContextVSchema) (engine.SetOp, error) {
ks, dest, err := resolveDestination(vschema)
if err != nil {
return nil, err
}

return &engine.SysVarSet{
Name: expr.Name.Lowered(),
Keyspace: ks,
TargetDestination: dest,
Expr: sqlparser.String(expr.Expr),
}, nil
}

func resolveDestination(vschema ContextVSchema) (*vindexes.Keyspace, key.Destination, error) {
keyspace, err := vschema.AnyKeyspace()
if err != nil {
return nil, nil, err
}

dest := vschema.Destination()
if dest == nil {
dest = key.DestinationAnyShard{}
}
return keyspace, dest, nil
}

// whitelist of functions knows to be safe to pass through to mysql for evaluation
// this list tries to not include functions that might return different results on different tablets
var validFuncs = map[string]interface{}{
Expand Down Expand Up @@ -292,20 +325,3 @@ var validFuncs = map[string]interface{}{
"upper": nil,
"weight_string": nil,
}

func expressionOkToDelegateToTablet(e sqlparser.Expr) bool {
valid := true
sqlparser.Rewrite(e, nil, func(cursor *sqlparser.Cursor) bool {
switch n := cursor.Node().(type) {
case *sqlparser.Subquery, *sqlparser.TimestampFuncExpr, *sqlparser.CurTimeFuncExpr:
valid = false
return false
case *sqlparser.FuncExpr:
_, ok := validFuncs[n.Name.Lowered()]
valid = ok
return ok
}
return true
})
return valid
}
27 changes: 27 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/set_sysvar_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@
}
}

# set check and ignore plan
"set @@sql_safe_updates = 1"
{
"QueryType": "SET",
"Original": "set @@sql_safe_updates = 1",
"Instructions": {
"OperatorType": "Set",
"Ops": [
{
"Type": "SysVarSet",
"Name": "sql_safe_updates",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"TargetDestination": {},
"Expr": "1"
}
],
"Inputs": [
{
"OperatorType": "SingleRow"
}
]
}
}

# set plan building not supported
"set @@innodb_strict_mode = OFF"
"plan building not supported"
7 changes: 7 additions & 0 deletions go/vt/vtgate/safe_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,10 @@ func (session *SafeSession) SetTargetString(target string) {
defer session.mu.Unlock()
session.TargetString = target
}

//SetSystemVariable sets the system variable in th session.
func (session *SafeSession) SetSystemVariable(name string, expr string) {
session.mu.Lock()
defer session.mu.Unlock()
session.SystemVariables[name] = expr
}
Loading