diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 6a12fc4fab4..fb88acd5718 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -10,6 +10,7 @@ import ( "github.com/youtube/vitess/go/sqltypes" querypb "github.com/youtube/vitess/go/vt/proto/query" + "github.com/youtube/vitess/go/vt/vtgate/queryinfo" ) // Join specifies the parameters for a join primitive. @@ -33,8 +34,8 @@ type Join struct { } // Execute performs a non-streaming exec. -func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) { - lresult, err := jn.Left.Execute(vcursor, joinvars, wantfields) +func (jn *Join) Execute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) { + lresult, err := jn.Left.Execute(vcursor, queryConstruct, joinvars, wantfields) if err != nil { return nil, err } @@ -43,7 +44,7 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi for k := range jn.Vars { joinvars[k] = nil } - rresult, err := jn.Right.GetFields(vcursor, joinvars) + rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars) if err != nil { return nil, err } @@ -54,7 +55,7 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi for k, col := range jn.Vars { joinvars[k] = lrow[col] } - rresult, err := jn.Right.Execute(vcursor, joinvars, wantfields) + rresult, err := jn.Right.Execute(vcursor, queryConstruct, joinvars, wantfields) if err != nil { return nil, err } @@ -76,14 +77,14 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi } // StreamExecute performs a streaming exec. -func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error { - err := jn.Left.StreamExecute(vcursor, joinvars, wantfields, func(lresult *sqltypes.Result) error { +func (jn *Join) StreamExecute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error { + err := jn.Left.StreamExecute(vcursor, queryConstruct, joinvars, wantfields, func(lresult *sqltypes.Result) error { for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinvars[k] = lrow[col] } rowSent := false - err := jn.Right.StreamExecute(vcursor, joinvars, wantfields, func(rresult *sqltypes.Result) error { + err := jn.Right.StreamExecute(vcursor, queryConstruct, joinvars, wantfields, func(rresult *sqltypes.Result) error { result := &sqltypes.Result{} if wantfields { wantfields = false @@ -120,7 +121,7 @@ func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{}, joinvars[k] = nil } result := &sqltypes.Result{} - rresult, err := jn.Right.GetFields(vcursor, joinvars) + rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars) if err != nil { return err } @@ -133,8 +134,8 @@ func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{}, } // GetFields fetches the field info. -func (jn *Join) GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sqltypes.Result, error) { - lresult, err := jn.Left.GetFields(vcursor, joinvars) +func (jn *Join) GetFields(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}) (*sqltypes.Result, error) { + lresult, err := jn.Left.GetFields(vcursor, queryConstruct, joinvars) if err != nil { return nil, err } @@ -142,7 +143,7 @@ func (jn *Join) GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sq for k := range jn.Vars { joinvars[k] = nil } - rresult, err := jn.Right.GetFields(vcursor, joinvars) + rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars) if err != nil { return nil, err } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index aacb9af0aae..a3c8ae1c272 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -4,7 +4,12 @@ package engine -import "github.com/youtube/vitess/go/sqltypes" +import ( + "github.com/youtube/vitess/go/sqltypes" + topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" + "github.com/youtube/vitess/go/vt/tabletserver/querytypes" + "github.com/youtube/vitess/go/vt/vtgate/queryinfo" +) // SeqVarName is a reserved bind var name for sequence values. const SeqVarName = "__seq" @@ -17,9 +22,14 @@ const ListVarName = "__vals" // VCursor defines the interface the engine will use // to execute routes. type VCursor interface { - ExecuteRoute(route *Route, joinvars map[string]interface{}) (*sqltypes.Result, error) - StreamExecuteRoute(route *Route, joinvars map[string]interface{}, sendReply func(*sqltypes.Result) error) error - GetRouteFields(route *Route, joinvars map[string]interface{}) (*sqltypes.Result, error) + ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery, notInTransaction bool) (*sqltypes.Result, error) + StreamExecuteMulti(query string, keyspace string, shardVars map[string]map[string]interface{}, sendReply func(reply *sqltypes.Result) error) error + GetAnyShard(keyspace string) (ks, shard string, err error) + ScatterConnExecute(query string, bindVars map[string]interface{}, keyspace string, shards []string, notInTransaction bool) (*sqltypes.Result, error) + GetKeyspaceShards(keyspace string) (string, *topodatapb.SrvKeyspace, []*topodatapb.ShardReference, error) + GetShardForKeyspaceID(allShards []*topodatapb.ShardReference, keyspaceID []byte) (string, error) + ExecuteShard(keyspace string, shardQueries map[string]querytypes.BoundQuery) (*sqltypes.Result, error) + Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error) } // Plan represents the execution strategy for a given query. @@ -45,7 +55,7 @@ func (pln *Plan) Size() int { // Primitive is the interface that needs to be satisfied by // all primitives of a plan. type Primitive interface { - Execute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) - StreamExecute(vcursor VCursor, joinvars map[string]interface{}, wantields bool, sendReply func(*sqltypes.Result) error) error - GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sqltypes.Result, error) + Execute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) + StreamExecute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantields bool, sendReply func(*sqltypes.Result) error) error + GetFields(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}) (*sqltypes.Result, error) } diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 1ee4aef135b..e6513de798f 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -8,7 +8,15 @@ import ( "encoding/json" "fmt" + "encoding/hex" + "strconv" + "strings" + "github.com/youtube/vitess/go/sqltypes" + querypb "github.com/youtube/vitess/go/vt/proto/query" + "github.com/youtube/vitess/go/vt/sqlannotation" + "github.com/youtube/vitess/go/vt/tabletserver/querytypes" + "github.com/youtube/vitess/go/vt/vtgate/queryinfo" "github.com/youtube/vitess/go/vt/vtgate/vindexes" ) @@ -31,30 +39,15 @@ type Route struct { Suffix string } -// Execute performs a non-streaming exec. -func (rt *Route) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) { - return vcursor.ExecuteRoute(rt, joinvars) -} - -// StreamExecute performs a streaming exec. -func (rt *Route) StreamExecute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error { - return vcursor.StreamExecuteRoute(rt, joinvars, sendReply) -} - -// GetFields fetches the field info. -func (rt *Route) GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sqltypes.Result, error) { - return vcursor.GetRouteFields(rt, joinvars) -} - // MarshalJSON serializes the Route into a JSON representation. // It's used for testing and diagnostics. -func (rt *Route) MarshalJSON() ([]byte, error) { +func (route *Route) MarshalJSON() ([]byte, error) { var tname, vindexName string - if rt.Table != nil { - tname = rt.Table.Name.String() + if route.Table != nil { + tname = route.Table.Name.String() } - if rt.Vindex != nil { - vindexName = rt.Vindex.String() + if route.Vindex != nil { + vindexName = route.Vindex.String() } marshalRoute := struct { Opcode RouteOpcode `json:",omitempty"` @@ -71,19 +64,19 @@ func (rt *Route) MarshalJSON() ([]byte, error) { Mid []string `json:",omitempty"` Suffix string `json:",omitempty"` }{ - Opcode: rt.Opcode, - Keyspace: rt.Keyspace, - Query: rt.Query, - FieldQuery: rt.FieldQuery, + Opcode: route.Opcode, + Keyspace: route.Keyspace, + Query: route.Query, + FieldQuery: route.FieldQuery, Vindex: vindexName, - Values: prettyValue(rt.Values), - JoinVars: rt.JoinVars, + Values: prettyValue(route.Values), + JoinVars: route.JoinVars, Table: tname, - Subquery: rt.Subquery, - Generate: rt.Generate, - Prefix: rt.Prefix, - Mid: rt.Mid, - Suffix: rt.Suffix, + Subquery: route.Subquery, + Generate: route.Generate, + Prefix: route.Prefix, + Mid: route.Mid, + Suffix: route.Suffix, } return json.Marshal(marshalRoute) } @@ -231,3 +224,589 @@ func (code RouteOpcode) String() string { func (code RouteOpcode) MarshalJSON() ([]byte, error) { return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil } + +type scatterParams struct { + ks string + shardVars map[string]map[string]interface{} +} + +func newScatterParams(ks string, bv map[string]interface{}, shards []string) *scatterParams { + shardVars := make(map[string]map[string]interface{}, len(shards)) + for _, shard := range shards { + shardVars[shard] = bv + } + return &scatterParams{ + ks: ks, + shardVars: shardVars, + } +} + +// Execute performs a non-streaming exec. +func (route *Route) Execute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) { + saved := copyBindVars(queryConstruct.BindVars) + defer func() { queryConstruct.BindVars = saved }() + for k, v := range joinvars { + queryConstruct.BindVars[k] = v + } + + switch route.Opcode { + case UpdateEqual: + return route.execUpdateEqual(vcursor, queryConstruct) + case DeleteEqual: + return route.execDeleteEqual(vcursor, queryConstruct) + case InsertSharded: + return route.execInsertSharded(vcursor, queryConstruct) + case InsertUnsharded: + return route.execInsertUnsharded(vcursor, queryConstruct) + } + + var err error + var params *scatterParams + switch route.Opcode { + case SelectUnsharded, UpdateUnsharded, + DeleteUnsharded: + params, err = route.paramsUnsharded(vcursor, queryConstruct) + case SelectEqual, SelectEqualUnique: + params, err = route.paramsSelectEqual(vcursor, queryConstruct) + case SelectIN: + params, err = route.paramsSelectIN(vcursor, queryConstruct) + case SelectScatter: + params, err = route.paramsSelectScatter(vcursor, queryConstruct) + default: + // TODO(sougou): improve error. + return nil, fmt.Errorf("unsupported query route: %v", route) + } + if err != nil { + return nil, err + } + + shardQueries := route.getShardQueries(route.Query+queryConstruct.Comments, params) + return vcursor.ExecuteMultiShard(params.ks, shardQueries, queryConstruct.NotInTransaction) +} + +// StreamExecute performs a streaming exec. +func (route *Route) StreamExecute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error { + saved := copyBindVars(queryConstruct.BindVars) + defer func() { queryConstruct.BindVars = saved }() + for k, v := range joinvars { + queryConstruct.BindVars[k] = v + } + + var err error + var params *scatterParams + switch route.Opcode { + case SelectUnsharded: + params, err = route.paramsUnsharded(vcursor, queryConstruct) + case SelectEqual, SelectEqualUnique: + params, err = route.paramsSelectEqual(vcursor, queryConstruct) + case SelectIN: + params, err = route.paramsSelectIN(vcursor, queryConstruct) + case SelectScatter: + params, err = route.paramsSelectScatter(vcursor, queryConstruct) + default: + return fmt.Errorf("query %q cannot be used for streaming", route.Query) + } + if err != nil { + return err + } + return vcursor.StreamExecuteMulti( + route.Query+queryConstruct.Comments, + params.ks, + params.shardVars, + sendReply, + ) +} + +// GetFields fetches the field info. +func (route *Route) GetFields(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}) (*sqltypes.Result, error) { + saved := copyBindVars(queryConstruct.BindVars) + defer func() { queryConstruct.BindVars = saved }() + for k := range joinvars { + queryConstruct.BindVars[k] = nil + } + ks, shard, err := vcursor.GetAnyShard(route.Keyspace.Name) + if err != nil { + return nil, err + } + + return vcursor.ScatterConnExecute(route.FieldQuery, queryConstruct.BindVars, ks, []string{shard}, queryConstruct.NotInTransaction) +} + +func copyBindVars(bindVars map[string]interface{}) map[string]interface{} { + out := make(map[string]interface{}) + for k, v := range bindVars { + out[k] = v + } + return out +} + +func (route *Route) paramsUnsharded(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*scatterParams, error) { + ks, _, allShards, err := vcursor.GetKeyspaceShards(route.Keyspace.Name) + if err != nil { + return nil, fmt.Errorf("paramsUnsharded: %v", err) + } + if len(allShards) != 1 { + return nil, fmt.Errorf("unsharded keyspace %s has multiple shards", ks) + } + return newScatterParams(ks, queryConstruct.BindVars, []string{allShards[0].Name}), nil +} + +func (route *Route) paramsSelectEqual(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*scatterParams, error) { + keys, err := route.resolveKeys([]interface{}{route.Values}, queryConstruct.BindVars) + if err != nil { + return nil, fmt.Errorf("paramsSelectEqual: %v", err) + } + ks, routing, err := route.resolveShards(vcursor, queryConstruct, keys) + if err != nil { + return nil, fmt.Errorf("paramsSelectEqual: %v", err) + } + return newScatterParams(ks, queryConstruct.BindVars, routing.Shards()), nil +} + +func (route *Route) paramsSelectIN(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*scatterParams, error) { + vals, err := route.resolveList(route.Values, queryConstruct.BindVars) + if err != nil { + return nil, fmt.Errorf("paramsSelectIN: %v", err) + } + keys, err := route.resolveKeys(vals, queryConstruct.BindVars) + if err != nil { + return nil, fmt.Errorf("paramsSelectIN: %v", err) + } + ks, routing, err := route.resolveShards(vcursor, queryConstruct, keys) + if err != nil { + return nil, fmt.Errorf("paramsSelectEqual: %v", err) + } + return &scatterParams{ + ks: ks, + shardVars: routing.ShardVars(queryConstruct.BindVars), + }, nil +} + +func (route *Route) paramsSelectScatter(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*scatterParams, error) { + ks, _, allShards, err := vcursor.GetKeyspaceShards(route.Keyspace.Name) + if err != nil { + return nil, fmt.Errorf("paramsSelectScatter: %v", err) + } + var shards []string + for _, shard := range allShards { + shards = append(shards, shard.Name) + } + return newScatterParams(ks, queryConstruct.BindVars, shards), nil +} + +func (route *Route) execUpdateEqual(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*sqltypes.Result, error) { + keys, err := route.resolveKeys([]interface{}{route.Values}, queryConstruct.BindVars) + if err != nil { + return nil, fmt.Errorf("execUpdateEqual: %v", err) + } + ks, shard, ksid, err := route.resolveSingleShard(vcursor, queryConstruct, keys[0]) + if err != nil { + return nil, fmt.Errorf("execUpdateEqual: %v", err) + } + if len(ksid) == 0 { + return &sqltypes.Result{}, nil + } + rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, queryConstruct.Comments) + return vcursor.ScatterConnExecute(rewritten, queryConstruct.BindVars, ks, []string{shard}, queryConstruct.NotInTransaction) +} + +func (route *Route) execDeleteEqual(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*sqltypes.Result, error) { + keys, err := route.resolveKeys([]interface{}{route.Values}, queryConstruct.BindVars) + if err != nil { + return nil, fmt.Errorf("execDeleteEqual: %v", err) + } + ks, shard, ksid, err := route.resolveSingleShard(vcursor, queryConstruct, keys[0]) + if err != nil { + return nil, fmt.Errorf("execDeleteEqual: %v", err) + } + if len(ksid) == 0 { + return &sqltypes.Result{}, nil + } + if route.Subquery != "" { + err = route.deleteVindexEntries(vcursor, queryConstruct, ks, shard, ksid) + if err != nil { + return nil, fmt.Errorf("execDeleteEqual: %v", err) + } + } + rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, queryConstruct.Comments) + return vcursor.ScatterConnExecute(rewritten, queryConstruct.BindVars, ks, []string{shard}, queryConstruct.NotInTransaction) +} + +func (route *Route) execInsertUnsharded(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*sqltypes.Result, error) { + insertid, err := route.handleGenerate(vcursor, queryConstruct) + if err != nil { + return nil, fmt.Errorf("execInsertUnsharded: %v", err) + } + params, err := route.paramsUnsharded(vcursor, queryConstruct) + if err != nil { + return nil, fmt.Errorf("execInsertUnsharded: %v", err) + } + + shardQueries := route.getShardQueries(route.Query+queryConstruct.Comments, params) + result, err := vcursor.ExecuteMultiShard(params.ks, shardQueries, queryConstruct.NotInTransaction) + if err != nil { + return nil, fmt.Errorf("execInsertUnsharded: %v", err) + } + if insertid != 0 { + if result.InsertID != 0 { + return nil, fmt.Errorf("sequence and db generated a value each for insert") + } + result.InsertID = uint64(insertid) + } + return result, nil +} + +func (route *Route) execInsertSharded(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (*sqltypes.Result, error) { + insertid, err := route.handleGenerate(vcursor, queryConstruct) + if err != nil { + return nil, fmt.Errorf("execInsertSharded: %v", err) + } + keyspace, shardQueries, err := route.getInsertShardedRoute(vcursor, queryConstruct) + if err != nil { + return nil, fmt.Errorf("execInsertSharded: %v", err) + } + + result, err := vcursor.ExecuteMultiShard(keyspace, shardQueries, queryConstruct.NotInTransaction) + + if err != nil { + return nil, fmt.Errorf("execInsertSharded: %v", err) + } + + if insertid != 0 { + if result.InsertID != 0 { + return nil, fmt.Errorf("sequence and db generated a value each for insert") + } + result.InsertID = uint64(insertid) + } + + return result, nil +} + +func (route *Route) getInsertShardedRoute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (keyspace string, shardQueries map[string]querytypes.BoundQuery, err error) { + keyspaceIDs := [][]byte{} + routing := make(map[string][]string) + shardKeyspaceIDMap := make(map[string][][]byte) + keyspace, _, allShards, err := vcursor.GetKeyspaceShards(route.Keyspace.Name) + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) + } + + inputs := route.Values.([]interface{}) + for rowNum, input := range inputs { + keys, err := route.resolveKeys(input.([]interface{}), queryConstruct.BindVars) + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) + } + for colNum := 0; colNum < len(keys); colNum++ { + if colNum == 0 { + ksid, err := route.handlePrimary(vcursor, queryConstruct, keys[colNum], route.Table.ColumnVindexes[colNum], queryConstruct.BindVars, rowNum) + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) + } + keyspaceIDs = append(keyspaceIDs, ksid) + } else { + err := route.handleNonPrimary(vcursor, queryConstruct, keys[colNum], route.Table.ColumnVindexes[colNum], queryConstruct.BindVars, keyspaceIDs[rowNum], rowNum) + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) + } + } + } + shard, err := vcursor.GetShardForKeyspaceID(allShards, keyspaceIDs[rowNum]) + routing[shard] = append(routing[shard], route.Mid[rowNum]) + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) + } + shardKeyspaceIDMap[shard] = append(shardKeyspaceIDMap[shard], keyspaceIDs[rowNum]) + } + + shardQueries = make(map[string]querytypes.BoundQuery, len(routing)) + for shard := range routing { + rewritten := route.Prefix + strings.Join(routing[shard], ",") + route.Suffix + if err != nil { + return "", nil, fmt.Errorf("getInsertShardedRoute: Error While Rewriting Query: %v", err) + } + rewrittenQuery := sqlannotation.AddKeyspaceIDs(rewritten, shardKeyspaceIDMap[shard], queryConstruct.Comments) + query := querytypes.BoundQuery{ + Sql: rewrittenQuery, + BindVariables: queryConstruct.BindVars, + } + shardQueries[shard] = query + } + + return keyspace, shardQueries, nil +} + +// resolveList returns a list of values, typically for an IN clause. If the input +// is a bind var name, it uses the list provided in the bind var. If the input is +// already a list, it returns just that. +func (route *Route) resolveList(val interface{}, bindVars map[string]interface{}) ([]interface{}, error) { + switch v := val.(type) { + case []interface{}: + return v, nil + case string: + // It can only be a list bind var. + list, ok := bindVars[v[2:]] + if !ok { + return nil, fmt.Errorf("could not find bind var %s", v) + } + + // Lists can be an []interface{}, or a *querypb.BindVariable + // with type TUPLE. + switch l := list.(type) { + case []interface{}: + return l, nil + case *querypb.BindVariable: + if l.Type != querypb.Type_TUPLE { + return nil, fmt.Errorf("expecting list for bind var %s: %v", v, list) + } + result := make([]interface{}, len(l.Values)) + for i, val := range l.Values { + // We can use MakeTrusted as the lower + // layers will verify the value if needed. + result[i] = sqltypes.MakeTrusted(val.Type, val.Value) + } + return result, nil + default: + return nil, fmt.Errorf("expecting list for bind var %s: %v", v, list) + } + default: + panic("unexpected") + } +} + +// resolveKeys takes a list as input that may have values or bind var names. +// It returns a new list with all the bind vars resolved. +func (route *Route) resolveKeys(vals []interface{}, bindVars map[string]interface{}) (keys []interface{}, err error) { + keys = make([]interface{}, 0, len(vals)) + for _, val := range vals { + if v, ok := val.(string); ok { + val, ok = bindVars[v[1:]] + if !ok { + return nil, fmt.Errorf("could not find bind var %s", v) + } + } + keys = append(keys, val) + } + return keys, nil +} + +func (route *Route) resolveShards(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, vindexKeys []interface{}) (newKeyspace string, routing routingMap, err error) { + newKeyspace, _, allShards, err := vcursor.GetKeyspaceShards(route.Keyspace.Name) + if err != nil { + return "", nil, err + } + routing = make(routingMap) + switch mapper := route.Vindex.(type) { + case vindexes.Unique: + ksids, err := mapper.Map(vcursor, vindexKeys) + if err != nil { + return "", nil, err + } + for i, ksid := range ksids { + if len(ksid) == 0 { + continue + } + shard, err := vcursor.GetShardForKeyspaceID(allShards, ksid) + if err != nil { + return "", nil, err + } + routing.Add(shard, vindexKeys[i]) + } + case vindexes.NonUnique: + ksidss, err := mapper.Map(vcursor, vindexKeys) + if err != nil { + return "", nil, err + } + for i, ksids := range ksidss { + for _, ksid := range ksids { + shard, err := vcursor.GetShardForKeyspaceID(allShards, ksid) + if err != nil { + return "", nil, err + } + routing.Add(shard, vindexKeys[i]) + } + } + default: + panic("unexpected") + } + return newKeyspace, routing, nil +} + +func (route *Route) resolveSingleShard(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, vindexKey interface{}) (newKeyspace, shard string, ksid []byte, err error) { + newKeyspace, _, allShards, err := vcursor.GetKeyspaceShards(route.Keyspace.Name) + if err != nil { + return "", "", nil, err + } + mapper := route.Vindex.(vindexes.Unique) + ksids, err := mapper.Map(vcursor, []interface{}{vindexKey}) + if err != nil { + return "", "", nil, err + } + ksid = ksids[0] + if len(ksid) == 0 { + return "", "", ksid, nil + } + shard, err = vcursor.GetShardForKeyspaceID(allShards, ksid) + if err != nil { + return "", "", nil, err + } + return newKeyspace, shard, ksid, nil +} + +func (route *Route) deleteVindexEntries(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, ks, shard string, ksid []byte) error { + result, err := vcursor.ScatterConnExecute(route.Subquery, queryConstruct.BindVars, ks, []string{shard}, queryConstruct.NotInTransaction) + if err != nil { + return err + } + if len(result.Rows) == 0 { + return nil + } + for i, colVindex := range route.Table.Owned { + keys := make(map[interface{}]bool) + for _, row := range result.Rows { + switch k := row[i].ToNative().(type) { + case []byte: + keys[string(k)] = true + default: + keys[k] = true + } + } + var ids []interface{} + for k := range keys { + ids = append(ids, k) + } + switch vindex := colVindex.Vindex.(type) { + case vindexes.Lookup: + if err = vindex.Delete(vcursor, ids, ksid); err != nil { + return err + } + default: + panic("unexpceted") + } + } + return nil +} + +func (route *Route) handleGenerate(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct) (insertid int64, err error) { + if route.Generate == nil { + return 0, nil + } + count := 0 + resolved := make([]interface{}, len(route.Generate.Value.([]interface{}))) + for i, val := range route.Generate.Value.([]interface{}) { + if v, ok := val.(string); ok { + val, ok = queryConstruct.BindVars[v[1:]] + if !ok { + return 0, fmt.Errorf("handleGenerate: could not find bind var %s", v) + } + } + if val == nil { + count++ + } else if v, ok := val.(*querypb.BindVariable); ok && v.Type == sqltypes.Null { + count++ + } else { + resolved[i] = val + } + } + if count != 0 { + // TODO(sougou): This is similar to paramsUnsharded. + ks, _, allShards, err := vcursor.GetKeyspaceShards(route.Generate.Keyspace.Name) + if err != nil { + return 0, fmt.Errorf("handleGenerate: %v", err) + } + if len(allShards) != 1 { + return 0, fmt.Errorf("unsharded keyspace %s has multiple shards", ks) + } + params := newScatterParams(ks, map[string]interface{}{"n": int64(count)}, []string{allShards[0].Name}) + // We nil out the transaction context for this particular call. + // TODO(sougou): Use ExecuteShard instead. + shardQueries := route.getShardQueries(route.Generate.Query, params) + qr, err := vcursor.ExecuteShard(params.ks, shardQueries) + if err != nil { + return 0, err + } + // If no rows are returned, it's an internal error, and the code + // must panic, which will caught and reported. + insertid, err = qr.Rows[0][0].ParseInt64() + if err != nil { + return 0, err + } + } + cur := insertid + for i, v := range resolved { + if v != nil { + queryConstruct.BindVars[SeqVarName+strconv.Itoa(i)] = v + } else { + queryConstruct.BindVars[SeqVarName+strconv.Itoa(i)] = cur + cur++ + } + } + return insertid, nil +} + +func (route *Route) handlePrimary(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, vindexKey interface{}, colVindex *vindexes.ColumnVindex, bv map[string]interface{}, rowNum int) (ksid []byte, err error) { + if vindexKey == nil { + return nil, fmt.Errorf("value must be supplied for column %v", colVindex.Column) + } + mapper := colVindex.Vindex.(vindexes.Unique) + ksids, err := mapper.Map(vcursor, []interface{}{vindexKey}) + if err != nil { + return nil, err + } + ksid = ksids[0] + if len(ksid) == 0 { + return nil, fmt.Errorf("could not map %v to a keyspace id", vindexKey) + } + bv["_"+colVindex.Column.CompliantName()+strconv.Itoa(rowNum)] = vindexKey + return ksid, nil +} + +func (route *Route) handleNonPrimary(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, vindexKey interface{}, colVindex *vindexes.ColumnVindex, bv map[string]interface{}, ksid []byte, rowNum int) error { + if colVindex.Owned { + if vindexKey == nil { + return fmt.Errorf("value must be supplied for column %v", colVindex.Column) + } + err := colVindex.Vindex.(vindexes.Lookup).Create(vcursor, vindexKey, ksid) + if err != nil { + return err + } + } else { + if vindexKey == nil { + reversible, ok := colVindex.Vindex.(vindexes.Reversible) + if !ok { + return fmt.Errorf("value must be supplied for column %v", colVindex.Column) + } + var err error + vindexKey, err = reversible.ReverseMap(vcursor, ksid) + if err != nil { + return err + } + if vindexKey == nil { + return fmt.Errorf("could not compute value for column %v", colVindex.Column) + } + } else { + ok, err := colVindex.Vindex.Verify(vcursor, vindexKey, ksid) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("value %v for column %v does not map to keyspace id %v", vindexKey, colVindex.Column, hex.EncodeToString(ksid)) + } + } + } + bv["_"+colVindex.Column.CompliantName()+strconv.Itoa(rowNum)] = vindexKey + return nil +} + +func (route *Route) getShardQueries(query string, params *scatterParams) map[string]querytypes.BoundQuery { + + shardQueries := make(map[string]querytypes.BoundQuery, len(params.shardVars)) + for shard, shardVars := range params.shardVars { + query := querytypes.BoundQuery{ + Sql: query, + BindVariables: shardVars, + } + shardQueries[shard] = query + } + return shardQueries +} diff --git a/go/vt/vtgate/routing_map.go b/go/vt/vtgate/engine/routing_map.go similarity index 88% rename from go/vt/vtgate/routing_map.go rename to go/vt/vtgate/engine/routing_map.go index a04b4f99128..494dd6811aa 100644 --- a/go/vt/vtgate/routing_map.go +++ b/go/vt/vtgate/engine/routing_map.go @@ -2,9 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package vtgate - -import "github.com/youtube/vitess/go/vt/vtgate/engine" +package engine type routingMap map[string][]interface{} @@ -27,7 +25,7 @@ func (rtm routingMap) ShardVars(bv map[string]interface{}) map[string]map[string for k, v := range bv { newbv[k] = v } - newbv[engine.ListVarName] = vals + newbv[ListVarName] = vals shardVars[shard] = newbv } return shardVars diff --git a/go/vt/vtgate/query_executor.go b/go/vt/vtgate/query_executor.go index 24cbf356d7d..5c7e62e881a 100644 --- a/go/vt/vtgate/query_executor.go +++ b/go/vt/vtgate/query_executor.go @@ -8,56 +8,68 @@ import ( "golang.org/x/net/context" "github.com/youtube/vitess/go/sqltypes" - "github.com/youtube/vitess/go/vt/sqlparser" - "github.com/youtube/vitess/go/vt/vtgate/engine" querypb "github.com/youtube/vitess/go/vt/proto/query" topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" vtgatepb "github.com/youtube/vitess/go/vt/proto/vtgate" + "github.com/youtube/vitess/go/vt/tabletserver/querytypes" ) type queryExecutor struct { - ctx context.Context - sql, comments string - bindVars map[string]interface{} - keyspace string - tabletType topodatapb.TabletType - session *vtgatepb.Session - notInTransaction bool - options *querypb.ExecuteOptions - router *Router -} - -func newQueryExecutor(ctx context.Context, sql string, bindVars map[string]interface{}, keyspace string, tabletType topodatapb.TabletType, session *vtgatepb.Session, notInTransaction bool, options *querypb.ExecuteOptions, router *Router) *queryExecutor { - query, comments := sqlparser.SplitTrailingComments(sql) + ctx context.Context + tabletType topodatapb.TabletType + session *vtgatepb.Session + options *querypb.ExecuteOptions + router *Router +} + +func newQueryExecutor(ctx context.Context, tabletType topodatapb.TabletType, session *vtgatepb.Session, options *querypb.ExecuteOptions, router *Router) *queryExecutor { return &queryExecutor{ - ctx: ctx, - sql: query, - comments: comments, - bindVars: bindVars, - keyspace: keyspace, - tabletType: tabletType, - session: session, - notInTransaction: notInTransaction, - options: options, - router: router, + ctx: ctx, + tabletType: tabletType, + session: session, + options: options, + router: router, } } +// Execute method call from vindex call to vtgate. func (vc *queryExecutor) Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error) { // We have to use an empty keyspace here, becasue vindexes that call back can reference // any table. return vc.router.Execute(vc.ctx, query, bindvars, "", vc.tabletType, vc.session, false, vc.options) } -func (vc *queryExecutor) ExecuteRoute(route *engine.Route, joinvars map[string]interface{}) (*sqltypes.Result, error) { - return vc.router.ExecuteRoute(vc, route, joinvars) +// ExecuteMultiShard method call from engine call to vtgate. +func (vc *queryExecutor) ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery, notInTransaction bool) (*sqltypes.Result, error) { + return vc.router.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, shardQueries, vc.tabletType, NewSafeSession(vc.session), notInTransaction, vc.options) +} + +// StreamExecuteMulti method call from engine call to vtgate. +func (vc *queryExecutor) StreamExecuteMulti(query string, keyspace string, shardVars map[string]map[string]interface{}, sendReply func(reply *sqltypes.Result) error) error { + return vc.router.scatterConn.StreamExecuteMulti(vc.ctx, query, keyspace, shardVars, vc.tabletType, vc.options, sendReply) +} + +// GetAnyShard method call from engine call to vtgate. +func (vc *queryExecutor) GetAnyShard(keyspace string) (ks, shard string, err error) { + return getAnyShard(vc.ctx, vc.router.serv, vc.router.cell, keyspace, vc.tabletType) +} + +// ScatterConnExecute method call from engine call to vtgate. +func (vc *queryExecutor) ScatterConnExecute(query string, bindVars map[string]interface{}, keyspace string, shards []string, notInTransaction bool) (*sqltypes.Result, error) { + return vc.router.scatterConn.Execute(vc.ctx, query, bindVars, keyspace, shards, vc.tabletType, NewSafeSession(vc.session), notInTransaction, vc.options) +} + +// GetKeyspaceShards method call from engine call to vtgate. +func (vc *queryExecutor) GetKeyspaceShards(keyspace string) (string, *topodatapb.SrvKeyspace, []*topodatapb.ShardReference, error) { + return getKeyspaceShards(vc.ctx, vc.router.serv, vc.router.cell, keyspace, vc.tabletType) } -func (vc *queryExecutor) StreamExecuteRoute(route *engine.Route, joinvars map[string]interface{}, sendReply func(*sqltypes.Result) error) error { - return vc.router.StreamExecuteRoute(vc, route, joinvars, sendReply) +// GetShardForKeyspaceID method call from engine call to vtgate. +func (vc *queryExecutor) GetShardForKeyspaceID(allShards []*topodatapb.ShardReference, keyspaceID []byte) (string, error) { + return getShardForKeyspaceID(allShards, keyspaceID) } -func (vc *queryExecutor) GetRouteFields(route *engine.Route, joinvars map[string]interface{}) (*sqltypes.Result, error) { - return vc.router.GetRouteFields(vc, route, joinvars) +func (vc *queryExecutor) ExecuteShard(keyspace string, shardQueries map[string]querytypes.BoundQuery) (*sqltypes.Result, error) { + return vc.router.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, shardQueries, vc.tabletType, NewSafeSession(nil), false, vc.options) } diff --git a/go/vt/vtgate/queryinfo/query_construct.go b/go/vt/vtgate/queryinfo/query_construct.go new file mode 100644 index 00000000000..a214241e666 --- /dev/null +++ b/go/vt/vtgate/queryinfo/query_construct.go @@ -0,0 +1,24 @@ +package queryinfo + +import "github.com/youtube/vitess/go/vt/sqlparser" + +// QueryConstruct contains the information about the sql and bindVars to be used by vtgate and engine. +type QueryConstruct struct { + SQL string + Comments string + Keyspace string + BindVars map[string]interface{} + NotInTransaction bool +} + +// NewQueryConstruct method initializes the structure. +func NewQueryConstruct(sql, keyspace string, bindVars map[string]interface{}, notInTransaction bool) *QueryConstruct { + query, comments := sqlparser.SplitTrailingComments(sql) + return &QueryConstruct{ + SQL: query, + Comments: comments, + Keyspace: keyspace, + BindVars: bindVars, + NotInTransaction: notInTransaction, + } +} diff --git a/go/vt/vtgate/router.go b/go/vt/vtgate/router.go index 8b95b52f572..b63e682d969 100644 --- a/go/vt/vtgate/router.go +++ b/go/vt/vtgate/router.go @@ -5,23 +5,14 @@ package vtgate import ( - "encoding/hex" - "fmt" - "strconv" - "github.com/youtube/vitess/go/sqltypes" - "github.com/youtube/vitess/go/vt/sqlannotation" "github.com/youtube/vitess/go/vt/topo" - "github.com/youtube/vitess/go/vt/vtgate/engine" - "github.com/youtube/vitess/go/vt/vtgate/vindexes" "golang.org/x/net/context" - "strings" - querypb "github.com/youtube/vitess/go/vt/proto/query" topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" vtgatepb "github.com/youtube/vitess/go/vt/proto/vtgate" - "github.com/youtube/vitess/go/vt/tabletserver/querytypes" + "github.com/youtube/vitess/go/vt/vtgate/queryinfo" ) // Router is the layer to route queries to the correct shards @@ -33,22 +24,6 @@ type Router struct { scatterConn *ScatterConn } -type scatterParams struct { - ks string - shardVars map[string]map[string]interface{} -} - -func newScatterParams(ks string, bv map[string]interface{}, shards []string) *scatterParams { - shardVars := make(map[string]map[string]interface{}, len(shards)) - for _, shard := range shards { - shardVars[shard] = bv - } - return &scatterParams{ - ks: ks, - shardVars: shardVars, - } -} - // NewRouter creates a new Router. func NewRouter(ctx context.Context, serv topo.SrvTopoServer, cell, statsName string, scatterConn *ScatterConn, normalize bool) *Router { return &Router{ @@ -64,12 +39,13 @@ func (rtr *Router) Execute(ctx context.Context, sql string, bindVars map[string] if bindVars == nil { bindVars = make(map[string]interface{}) } - vcursor := newQueryExecutor(ctx, sql, bindVars, keyspace, tabletType, session, notInTransaction, options, rtr) + vcursor := newQueryExecutor(ctx, tabletType, session, options, rtr) + queryConstruct := queryinfo.NewQueryConstruct(sql, keyspace, bindVars, notInTransaction) plan, err := rtr.planner.GetPlan(sql, keyspace, bindVars) if err != nil { return nil, err } - return plan.Instructions.Execute(vcursor, make(map[string]interface{}), true) + return plan.Instructions.Execute(vcursor, queryConstruct, make(map[string]interface{}), true) } // StreamExecute executes a streaming query. @@ -77,12 +53,13 @@ func (rtr *Router) StreamExecute(ctx context.Context, sql string, bindVars map[s if bindVars == nil { bindVars = make(map[string]interface{}) } - vcursor := newQueryExecutor(ctx, sql, bindVars, keyspace, tabletType, nil, false, options, rtr) + vcursor := newQueryExecutor(ctx, tabletType, nil, options, rtr) + queryConstruct := queryinfo.NewQueryConstruct(sql, keyspace, bindVars, false) plan, err := rtr.planner.GetPlan(sql, keyspace, bindVars) if err != nil { return err } - return plan.Instructions.StreamExecute(vcursor, make(map[string]interface{}), true, sendReply) + return plan.Instructions.StreamExecute(vcursor, queryConstruct, make(map[string]interface{}), true, sendReply) } // ExecuteBatch routes a non-streaming queries. @@ -99,12 +76,13 @@ func (rtr *Router) ExecuteBatch(ctx context.Context, sqlList []string, bindVarsL bindVars = make(map[string]interface{}) } //Using same QueryExecutor -> marking notInTransaction as false and not using asTransaction flag - vcursor := newQueryExecutor(ctx, query, bindVars, keyspace, tabletType, session, false, options, rtr) + vcursor := newQueryExecutor(ctx, tabletType, session, options, rtr) + queryConstruct := queryinfo.NewQueryConstruct(query, keyspace, bindVars, false) plan, err := rtr.planner.GetPlan(query, keyspace, bindVars) if err != nil { queryResponse.QueryError = err } else { - result, err := plan.Instructions.Execute(vcursor, make(map[string]interface{}), true) + result, err := plan.Instructions.Execute(vcursor, queryConstruct, make(map[string]interface{}), true) queryResponse.QueryResult = result queryResponse.QueryError = err } @@ -113,126 +91,6 @@ func (rtr *Router) ExecuteBatch(ctx context.Context, sqlList []string, bindVarsL return queryResponseList, nil } -// ExecuteRoute executes the route query for all route opcodes. -func (rtr *Router) ExecuteRoute(vcursor *queryExecutor, route *engine.Route, joinvars map[string]interface{}) (*sqltypes.Result, error) { - saved := copyBindVars(vcursor.bindVars) - defer func() { vcursor.bindVars = saved }() - for k, v := range joinvars { - vcursor.bindVars[k] = v - } - - switch route.Opcode { - case engine.UpdateEqual: - return rtr.execUpdateEqual(vcursor, route) - case engine.DeleteEqual: - return rtr.execDeleteEqual(vcursor, route) - case engine.InsertSharded: - return rtr.execInsertSharded(vcursor, route) - case engine.InsertUnsharded: - return rtr.execInsertUnsharded(vcursor, route) - } - - var err error - var params *scatterParams - switch route.Opcode { - case engine.SelectUnsharded, engine.UpdateUnsharded, - engine.DeleteUnsharded: - params, err = rtr.paramsUnsharded(vcursor, route) - case engine.SelectEqual, engine.SelectEqualUnique: - params, err = rtr.paramsSelectEqual(vcursor, route) - case engine.SelectIN: - params, err = rtr.paramsSelectIN(vcursor, route) - case engine.SelectScatter: - params, err = rtr.paramsSelectScatter(vcursor, route) - default: - // TODO(sougou): improve error. - return nil, fmt.Errorf("unsupported query route: %v", route) - } - if err != nil { - return nil, err - } - - shardQueries := rtr.getShardQueries(vcursor, route.Query+vcursor.comments, params) - return rtr.scatterConn.ExecuteMultiShard( - vcursor.ctx, - params.ks, - shardQueries, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options, - ) -} - -func copyBindVars(bindVars map[string]interface{}) map[string]interface{} { - out := make(map[string]interface{}) - for k, v := range bindVars { - out[k] = v - } - return out -} - -// GetRouteFields fetches the field info for the route. -func (rtr *Router) GetRouteFields(vcursor *queryExecutor, route *engine.Route, joinvars map[string]interface{}) (*sqltypes.Result, error) { - saved := copyBindVars(vcursor.bindVars) - defer func() { vcursor.bindVars = saved }() - for k := range joinvars { - vcursor.bindVars[k] = nil - } - ks, shard, err := getAnyShard(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return nil, err - } - - return rtr.scatterConn.Execute( - vcursor.ctx, - route.FieldQuery, - vcursor.bindVars, - ks, - []string{shard}, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options, - ) -} - -// StreamExecuteRoute performs a streaming route. Only selects are allowed. -func (rtr *Router) StreamExecuteRoute(vcursor *queryExecutor, route *engine.Route, joinvars map[string]interface{}, sendReply func(*sqltypes.Result) error) error { - saved := copyBindVars(vcursor.bindVars) - defer func() { vcursor.bindVars = saved }() - for k, v := range joinvars { - vcursor.bindVars[k] = v - } - - var err error - var params *scatterParams - switch route.Opcode { - case engine.SelectUnsharded: - params, err = rtr.paramsUnsharded(vcursor, route) - case engine.SelectEqual, engine.SelectEqualUnique: - params, err = rtr.paramsSelectEqual(vcursor, route) - case engine.SelectIN: - params, err = rtr.paramsSelectIN(vcursor, route) - case engine.SelectScatter: - params, err = rtr.paramsSelectScatter(vcursor, route) - default: - return fmt.Errorf("query %q cannot be used for streaming", route.Query) - } - if err != nil { - return err - } - return rtr.scatterConn.StreamExecuteMulti( - vcursor.ctx, - route.Query+vcursor.comments, - params.ks, - params.shardVars, - vcursor.tabletType, - vcursor.options, - sendReply, - ) -} - // IsKeyspaceRangeBasedSharded returns true if the keyspace in the vschema is // marked as sharded. func (rtr *Router) IsKeyspaceRangeBasedSharded(keyspace string) bool { @@ -246,524 +104,3 @@ func (rtr *Router) IsKeyspaceRangeBasedSharded(keyspace string) bool { } return ks.Keyspace.Sharded } - -func (rtr *Router) paramsUnsharded(vcursor *queryExecutor, route *engine.Route) (*scatterParams, error) { - ks, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return nil, fmt.Errorf("paramsUnsharded: %v", err) - } - if len(allShards) != 1 { - return nil, fmt.Errorf("unsharded keyspace %s has multiple shards", ks) - } - return newScatterParams(ks, vcursor.bindVars, []string{allShards[0].Name}), nil -} - -func (rtr *Router) paramsSelectEqual(vcursor *queryExecutor, route *engine.Route) (*scatterParams, error) { - keys, err := rtr.resolveKeys([]interface{}{route.Values}, vcursor.bindVars) - if err != nil { - return nil, fmt.Errorf("paramsSelectEqual: %v", err) - } - ks, routing, err := rtr.resolveShards(vcursor, keys, route) - if err != nil { - return nil, fmt.Errorf("paramsSelectEqual: %v", err) - } - return newScatterParams(ks, vcursor.bindVars, routing.Shards()), nil -} - -func (rtr *Router) paramsSelectIN(vcursor *queryExecutor, route *engine.Route) (*scatterParams, error) { - vals, err := rtr.resolveList(route.Values, vcursor.bindVars) - if err != nil { - return nil, fmt.Errorf("paramsSelectIN: %v", err) - } - keys, err := rtr.resolveKeys(vals, vcursor.bindVars) - if err != nil { - return nil, fmt.Errorf("paramsSelectIN: %v", err) - } - ks, routing, err := rtr.resolveShards(vcursor, keys, route) - if err != nil { - return nil, fmt.Errorf("paramsSelectEqual: %v", err) - } - return &scatterParams{ - ks: ks, - shardVars: routing.ShardVars(vcursor.bindVars), - }, nil -} - -func (rtr *Router) paramsSelectScatter(vcursor *queryExecutor, route *engine.Route) (*scatterParams, error) { - ks, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return nil, fmt.Errorf("paramsSelectScatter: %v", err) - } - var shards []string - for _, shard := range allShards { - shards = append(shards, shard.Name) - } - return newScatterParams(ks, vcursor.bindVars, shards), nil -} - -func (rtr *Router) execUpdateEqual(vcursor *queryExecutor, route *engine.Route) (*sqltypes.Result, error) { - keys, err := rtr.resolveKeys([]interface{}{route.Values}, vcursor.bindVars) - if err != nil { - return nil, fmt.Errorf("execUpdateEqual: %v", err) - } - ks, shard, ksid, err := rtr.resolveSingleShard(vcursor, keys[0], route) - if err != nil { - return nil, fmt.Errorf("execUpdateEqual: %v", err) - } - if len(ksid) == 0 { - return &sqltypes.Result{}, nil - } - rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, vcursor.comments) - return rtr.scatterConn.Execute( - vcursor.ctx, - rewritten, - vcursor.bindVars, - ks, - []string{shard}, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options) -} - -func (rtr *Router) execDeleteEqual(vcursor *queryExecutor, route *engine.Route) (*sqltypes.Result, error) { - keys, err := rtr.resolveKeys([]interface{}{route.Values}, vcursor.bindVars) - if err != nil { - return nil, fmt.Errorf("execDeleteEqual: %v", err) - } - ks, shard, ksid, err := rtr.resolveSingleShard(vcursor, keys[0], route) - if err != nil { - return nil, fmt.Errorf("execDeleteEqual: %v", err) - } - if len(ksid) == 0 { - return &sqltypes.Result{}, nil - } - if route.Subquery != "" { - err = rtr.deleteVindexEntries(vcursor, route, ks, shard, ksid) - if err != nil { - return nil, fmt.Errorf("execDeleteEqual: %v", err) - } - } - rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, vcursor.comments) - return rtr.scatterConn.Execute( - vcursor.ctx, - rewritten, - vcursor.bindVars, - ks, - []string{shard}, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options) -} - -func (rtr *Router) execInsertUnsharded(vcursor *queryExecutor, route *engine.Route) (*sqltypes.Result, error) { - insertid, err := rtr.handleGenerate(vcursor, route.Generate) - if err != nil { - return nil, fmt.Errorf("execInsertUnsharded: %v", err) - } - params, err := rtr.paramsUnsharded(vcursor, route) - if err != nil { - return nil, fmt.Errorf("execInsertUnsharded: %v", err) - } - - shardQueries := rtr.getShardQueries(vcursor, route.Query+vcursor.comments, params) - result, err := rtr.scatterConn.ExecuteMultiShard( - vcursor.ctx, - params.ks, - shardQueries, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options, - ) - if err != nil { - return nil, fmt.Errorf("execInsertUnsharded: %v", err) - } - if insertid != 0 { - if result.InsertID != 0 { - return nil, fmt.Errorf("sequence and db generated a value each for insert") - } - result.InsertID = uint64(insertid) - } - return result, nil -} - -func (rtr *Router) execInsertSharded(vcursor *queryExecutor, route *engine.Route) (*sqltypes.Result, error) { - insertid, err := rtr.handleGenerate(vcursor, route.Generate) - if err != nil { - return nil, fmt.Errorf("execInsertSharded: %v", err) - } - keyspace, shardQueries, err := rtr.getInsertShardedRoute(vcursor, route) - if err != nil { - return nil, fmt.Errorf("execInsertSharded: %v", err) - } - - result, err := rtr.scatterConn.ExecuteMultiShard( - vcursor.ctx, - keyspace, - shardQueries, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options) - - if err != nil { - return nil, fmt.Errorf("execInsertSharded: %v", err) - } - - if insertid != 0 { - if result.InsertID != 0 { - return nil, fmt.Errorf("sequence and db generated a value each for insert") - } - result.InsertID = uint64(insertid) - } - - return result, nil -} - -func (rtr *Router) getInsertShardedRoute(vcursor *queryExecutor, route *engine.Route) (keyspace string, shardQueries map[string]querytypes.BoundQuery, err error) { - keyspaceIDs := [][]byte{} - routing := make(map[string][]string) - shardKeyspaceIDMap := make(map[string][][]byte) - keyspace, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) - } - - inputs := route.Values.([]interface{}) - for rowNum, input := range inputs { - keys, err := rtr.resolveKeys(input.([]interface{}), vcursor.bindVars) - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) - } - for colNum := 0; colNum < len(keys); colNum++ { - if colNum == 0 { - ksid, err := rtr.handlePrimary(vcursor, keys[colNum], route.Table.ColumnVindexes[colNum], vcursor.bindVars, rowNum) - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) - } - keyspaceIDs = append(keyspaceIDs, ksid) - } else { - err := rtr.handleNonPrimary(vcursor, keys[colNum], route.Table.ColumnVindexes[colNum], vcursor.bindVars, keyspaceIDs[rowNum], rowNum) - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) - } - } - } - shard, err := getShardForKeyspaceID(allShards, keyspaceIDs[rowNum]) - routing[shard] = append(routing[shard], route.Mid[rowNum]) - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: %v", err) - } - shardKeyspaceIDMap[shard] = append(shardKeyspaceIDMap[shard], keyspaceIDs[rowNum]) - } - - shardQueries = make(map[string]querytypes.BoundQuery, len(routing)) - for shard := range routing { - rewritten := route.Prefix + strings.Join(routing[shard], ",") + route.Suffix - if err != nil { - return "", nil, fmt.Errorf("getInsertShardedRoute: Error While Rewriting Query: %v", err) - } - rewrittenQuery := sqlannotation.AddKeyspaceIDs(rewritten, shardKeyspaceIDMap[shard], vcursor.comments) - query := querytypes.BoundQuery{ - Sql: rewrittenQuery, - BindVariables: vcursor.bindVars, - } - shardQueries[shard] = query - } - - return keyspace, shardQueries, nil -} - -// resolveList returns a list of values, typically for an IN clause. If the input -// is a bind var name, it uses the list provided in the bind var. If the input is -// already a list, it returns just that. -func (rtr *Router) resolveList(val interface{}, bindVars map[string]interface{}) ([]interface{}, error) { - switch v := val.(type) { - case []interface{}: - return v, nil - case string: - // It can only be a list bind var. - list, ok := bindVars[v[2:]] - if !ok { - return nil, fmt.Errorf("could not find bind var %s", v) - } - - // Lists can be an []interface{}, or a *querypb.BindVariable - // with type TUPLE. - switch l := list.(type) { - case []interface{}: - return l, nil - case *querypb.BindVariable: - if l.Type != querypb.Type_TUPLE { - return nil, fmt.Errorf("expecting list for bind var %s: %v", v, list) - } - result := make([]interface{}, len(l.Values)) - for i, val := range l.Values { - // We can use MakeTrusted as the lower - // layers will verify the value if needed. - result[i] = sqltypes.MakeTrusted(val.Type, val.Value) - } - return result, nil - default: - return nil, fmt.Errorf("expecting list for bind var %s: %v", v, list) - } - default: - panic("unexpected") - } -} - -// resolveKeys takes a list as input that may have values or bind var names. -// It returns a new list with all the bind vars resolved. -func (rtr *Router) resolveKeys(vals []interface{}, bindVars map[string]interface{}) (keys []interface{}, err error) { - keys = make([]interface{}, 0, len(vals)) - for _, val := range vals { - if v, ok := val.(string); ok { - val, ok = bindVars[v[1:]] - if !ok { - return nil, fmt.Errorf("could not find bind var %s", v) - } - } - keys = append(keys, val) - } - return keys, nil -} - -func (rtr *Router) resolveShards(vcursor *queryExecutor, vindexKeys []interface{}, route *engine.Route) (newKeyspace string, routing routingMap, err error) { - newKeyspace, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return "", nil, err - } - routing = make(routingMap) - switch mapper := route.Vindex.(type) { - case vindexes.Unique: - ksids, err := mapper.Map(vcursor, vindexKeys) - if err != nil { - return "", nil, err - } - for i, ksid := range ksids { - if len(ksid) == 0 { - continue - } - shard, err := getShardForKeyspaceID(allShards, ksid) - if err != nil { - return "", nil, err - } - routing.Add(shard, vindexKeys[i]) - } - case vindexes.NonUnique: - ksidss, err := mapper.Map(vcursor, vindexKeys) - if err != nil { - return "", nil, err - } - for i, ksids := range ksidss { - for _, ksid := range ksids { - shard, err := getShardForKeyspaceID(allShards, ksid) - if err != nil { - return "", nil, err - } - routing.Add(shard, vindexKeys[i]) - } - } - default: - panic("unexpected") - } - return newKeyspace, routing, nil -} - -func (rtr *Router) resolveSingleShard(vcursor *queryExecutor, vindexKey interface{}, route *engine.Route) (newKeyspace, shard string, ksid []byte, err error) { - newKeyspace, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, route.Keyspace.Name, vcursor.tabletType) - if err != nil { - return "", "", nil, err - } - mapper := route.Vindex.(vindexes.Unique) - ksids, err := mapper.Map(vcursor, []interface{}{vindexKey}) - if err != nil { - return "", "", nil, err - } - ksid = ksids[0] - if len(ksid) == 0 { - return "", "", ksid, nil - } - shard, err = getShardForKeyspaceID(allShards, ksid) - if err != nil { - return "", "", nil, err - } - return newKeyspace, shard, ksid, nil -} - -func (rtr *Router) deleteVindexEntries(vcursor *queryExecutor, route *engine.Route, ks, shard string, ksid []byte) error { - result, err := rtr.scatterConn.Execute( - vcursor.ctx, - route.Subquery, - vcursor.bindVars, - ks, - []string{shard}, - vcursor.tabletType, - NewSafeSession(vcursor.session), - vcursor.notInTransaction, - vcursor.options) - if err != nil { - return err - } - if len(result.Rows) == 0 { - return nil - } - for i, colVindex := range route.Table.Owned { - keys := make(map[interface{}]bool) - for _, row := range result.Rows { - switch k := row[i].ToNative().(type) { - case []byte: - keys[string(k)] = true - default: - keys[k] = true - } - } - var ids []interface{} - for k := range keys { - ids = append(ids, k) - } - switch vindex := colVindex.Vindex.(type) { - case vindexes.Lookup: - if err = vindex.Delete(vcursor, ids, ksid); err != nil { - return err - } - default: - panic("unexpceted") - } - } - return nil -} - -func (rtr *Router) handleGenerate(vcursor *queryExecutor, gen *engine.Generate) (insertid int64, err error) { - if gen == nil { - return 0, nil - } - count := 0 - resolved := make([]interface{}, len(gen.Value.([]interface{}))) - for i, val := range gen.Value.([]interface{}) { - if v, ok := val.(string); ok { - val, ok = vcursor.bindVars[v[1:]] - if !ok { - return 0, fmt.Errorf("handleGenerate: could not find bind var %s", v) - } - } - if val == nil { - count++ - } else if v, ok := val.(*querypb.BindVariable); ok && v.Type == sqltypes.Null { - count++ - } else { - resolved[i] = val - } - } - if count != 0 { - // TODO(sougou): This is similar to paramsUnsharded. - ks, _, allShards, err := getKeyspaceShards(vcursor.ctx, rtr.serv, rtr.cell, gen.Keyspace.Name, vcursor.tabletType) - if err != nil { - return 0, fmt.Errorf("handleGenerate: %v", err) - } - if len(allShards) != 1 { - return 0, fmt.Errorf("unsharded keyspace %s has multiple shards", ks) - } - params := newScatterParams(ks, map[string]interface{}{"n": int64(count)}, []string{allShards[0].Name}) - // We nil out the transaction context for this particular call. - // TODO(sougou): Use ExecuteShard instead. - shardQueries := rtr.getShardQueries(vcursor, gen.Query, params) - qr, err := rtr.scatterConn.ExecuteMultiShard( - vcursor.ctx, - params.ks, - shardQueries, - vcursor.tabletType, - NewSafeSession(nil), - false, - vcursor.options, - ) - if err != nil { - return 0, err - } - // If no rows are returned, it's an internal error, and the code - // must panic, which will caught and reported. - insertid, err = qr.Rows[0][0].ParseInt64() - if err != nil { - return 0, err - } - } - cur := insertid - for i, v := range resolved { - if v != nil { - vcursor.bindVars[engine.SeqVarName+strconv.Itoa(i)] = v - } else { - vcursor.bindVars[engine.SeqVarName+strconv.Itoa(i)] = cur - cur++ - } - } - return insertid, nil -} - -func (rtr *Router) handlePrimary(vcursor *queryExecutor, vindexKey interface{}, colVindex *vindexes.ColumnVindex, bv map[string]interface{}, rowNum int) (ksid []byte, err error) { - if vindexKey == nil { - return nil, fmt.Errorf("value must be supplied for column %v", colVindex.Column) - } - mapper := colVindex.Vindex.(vindexes.Unique) - ksids, err := mapper.Map(vcursor, []interface{}{vindexKey}) - if err != nil { - return nil, err - } - ksid = ksids[0] - if len(ksid) == 0 { - return nil, fmt.Errorf("could not map %v to a keyspace id", vindexKey) - } - bv["_"+colVindex.Column.CompliantName()+strconv.Itoa(rowNum)] = vindexKey - return ksid, nil -} - -func (rtr *Router) handleNonPrimary(vcursor *queryExecutor, vindexKey interface{}, colVindex *vindexes.ColumnVindex, bv map[string]interface{}, ksid []byte, rowNum int) error { - if colVindex.Owned { - if vindexKey == nil { - return fmt.Errorf("value must be supplied for column %v", colVindex.Column) - } - err := colVindex.Vindex.(vindexes.Lookup).Create(vcursor, vindexKey, ksid) - if err != nil { - return err - } - } else { - if vindexKey == nil { - reversible, ok := colVindex.Vindex.(vindexes.Reversible) - if !ok { - return fmt.Errorf("value must be supplied for column %v", colVindex.Column) - } - var err error - vindexKey, err = reversible.ReverseMap(vcursor, ksid) - if err != nil { - return err - } - if vindexKey == nil { - return fmt.Errorf("could not compute value for column %v", colVindex.Column) - } - } else { - ok, err := colVindex.Vindex.Verify(vcursor, vindexKey, ksid) - if err != nil { - return err - } - if !ok { - return fmt.Errorf("value %v for column %v does not map to keyspace id %v", vindexKey, colVindex.Column, hex.EncodeToString(ksid)) - } - } - } - bv["_"+colVindex.Column.CompliantName()+strconv.Itoa(rowNum)] = vindexKey - return nil -} - -func (rtr *Router) getShardQueries(vcursor *queryExecutor, query string, params *scatterParams) map[string]querytypes.BoundQuery { - - shardQueries := make(map[string]querytypes.BoundQuery, len(params.shardVars)) - for shard, shardVars := range params.shardVars { - query := querytypes.BoundQuery{ - Sql: query, - BindVariables: shardVars, - } - shardQueries[shard] = query - } - return shardQueries -}