diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 9143a149802..38d5f346f1c 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" ) @@ -56,6 +57,16 @@ func (c *Concatenate) GetTableName() string { return res } +// GetExecShards lists all the shards that would be accessed by this primitive +func (c *Concatenate) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + for _, src := range c.Sources { + if err := src.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + } + return nil +} + func formatTwoOptionsNicely(a, b string) string { if a == b { return a diff --git a/go/vt/vtgate/engine/dbddl.go b/go/vt/vtgate/engine/dbddl.go index f06c8150c2b..ce53fe2fa11 100644 --- a/go/vt/vtgate/engine/dbddl.go +++ b/go/vt/vtgate/engine/dbddl.go @@ -94,6 +94,12 @@ func (c *DBDDL) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (c *DBDDL) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // The DBDDL primitive is not shard-aware, it acts globally on the cluster + return nil +} + // Execute implements the Primitive interface func (c *DBDDL) Execute(vcursor VCursor, _ map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { name := vcursor.GetDBDDLPluginName() diff --git a/go/vt/vtgate/engine/ddl.go b/go/vt/vtgate/engine/ddl.go index 110e85e3d24..714de55cf68 100644 --- a/go/vt/vtgate/engine/ddl.go +++ b/go/vt/vtgate/engine/ddl.go @@ -22,6 +22,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -76,6 +77,15 @@ func (ddl *DDL) GetTableName() string { return ddl.DDL.GetTable().Name.String() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (ddl *DDL) GetExecShards(vcursor VCursor, bindVars map[string]*query.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + primitiveDDL, err := ddl.getPrimitiveToExecute(vcursor, false) + if err != nil { + return err + } + return primitiveDDL.GetExecShards(vcursor, bindVars, each) +} + // IsOnlineSchemaDDL returns true if the query is an online schema change DDL func (ddl *DDL) isOnlineSchemaDDL() bool { switch ddl.DDL.GetAction() { @@ -85,12 +95,13 @@ func (ddl *DDL) isOnlineSchemaDDL() bool { return false } -// Execute implements the Primitive interface -func (ddl *DDL) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (result *sqltypes.Result, err error) { +func (ddl *DDL) getPrimitiveToExecute(vcursor VCursor, updateSession bool) (Primitive, error) { if ddl.CreateTempTable { - vcursor.Session().HasCreatedTempTable() - vcursor.Session().NeedsReservedConn() - return ddl.NormalDDL.Execute(vcursor, bindVars, wantfields) + if updateSession { + vcursor.Session().HasCreatedTempTable() + vcursor.Session().NeedsReservedConn() + } + return ddl.NormalDDL, nil } ddlStrategySetting, err := schema.ParseDDLStrategy(vcursor.Session().GetDDLStrategy()) @@ -104,13 +115,22 @@ func (ddl *DDL) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable if !ddl.OnlineDDLEnabled { return nil, schema.ErrOnlineDDLDisabled } - return ddl.OnlineDDL.Execute(vcursor, bindVars, wantfields) + return ddl.OnlineDDL, nil default: // non online-ddl if !ddl.DirectDDLEnabled { return nil, schema.ErrDirectDDLDisabled } - return ddl.NormalDDL.Execute(vcursor, bindVars, wantfields) + return ddl.NormalDDL, nil + } +} + +// Execute implements the Primitive interface +func (ddl *DDL) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (result *sqltypes.Result, err error) { + primitiveDDL, err := ddl.getPrimitiveToExecute(vcursor, true) + if err != nil { + return nil, err } + return primitiveDDL.Execute(vcursor, bindVars, wantfields) } // StreamExecute implements the Primitive interface diff --git a/go/vt/vtgate/engine/delete.go b/go/vt/vtgate/engine/delete.go index 5d92df61630..07f10439ed0 100644 --- a/go/vt/vtgate/engine/delete.go +++ b/go/vt/vtgate/engine/delete.go @@ -68,6 +68,60 @@ func (del *Delete) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (del *Delete) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + switch del.Opcode { + case Unsharded: + rss, _, err := vcursor.ResolveDestinations(del.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return err + } + each(rss[0]) + return nil + case Equal: + key, err := del.Values[0].ResolveValue(bindVars) + if err != nil { + return err + } + rs, _, err := resolveSingleShard(vcursor, del.Vindex, del.Keyspace, key) + if err != nil { + return err + } + each(rs) + return nil + case In: + rss, _, err := resolveMultiValueShards(vcursor, del.Keyspace, del.Query, bindVars, del.Values[0], del.Vindex) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + case Scatter: + rss, _, err := vcursor.ResolveDestinations(del.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + case ByDestination: + rss, _, err := vcursor.ResolveDestinations(del.Keyspace.Name, nil, []key.Destination{del.TargetDestination}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + default: + // Unreachable. + return fmt.Errorf("unsupported opcode: %v", del) + } +} + // Execute performs a non-streaming exec. func (del *Delete) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { if del.QueryTimeout != 0 { diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index 9bf260869ae..37e429473dd 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -19,6 +19,7 @@ package engine import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -154,6 +155,11 @@ func (d *Distinct) GetTableName() string { return d.Source.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (d *Distinct) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return d.Source.GetExecShards(vcursor, bindVars, each) +} + // GetFields implements the Primitive interface func (d *Distinct) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { return d.Source.GetFields(vcursor, bindVars) diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index d1186e4491c..95d5e016b5b 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -23,6 +23,7 @@ import ( "testing" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/srvtopo" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -63,6 +64,10 @@ func (f *fakePrimitive) GetTableName() string { return "fakeTable" } +func (f *fakePrimitive) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return nil +} + func (f *fakePrimitive) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) if f.results == nil { diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 2ce825ffd5c..bffb5c1945c 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -187,6 +187,30 @@ func (ins *Insert) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (ins *Insert) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + switch ins.Opcode { + case InsertUnsharded: + rss, _, err := vcursor.ResolveDestinations(ins.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return err + } + each(rss[0]) + return nil + case InsertSharded, InsertShardedIgnore: + rss, _, err := ins.getInsertShardedRoute(vcursor, bindVars) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + default: + return fmt.Errorf("unsupported query route: %v", ins) + } +} + // Execute performs a non-streaming exec. func (ins *Insert) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { if ins.QueryTimeout != 0 { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index a909dddef3c..14da2739641 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" ) var _ Primitive = (*Join)(nil) @@ -168,6 +169,17 @@ func (jn *Join) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVari return result, nil } +// GetExecShards lists all the shards that would be accessed by this primitive +func (jn *Join) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + if err := jn.Left.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + if err := jn.Right.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + return nil +} + // Inputs returns the input primitives for this join func (jn *Join) Inputs() []Primitive { return []Primitive{jn.Left, jn.Right} diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 1ec801124ed..7648d8a2574 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -20,6 +20,7 @@ import ( "fmt" "io" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -51,6 +52,11 @@ func (l *Limit) GetTableName() string { return l.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (l *Limit) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return l.Input.GetExecShards(vcursor, bindVars, each) +} + // Execute satisfies the Primtive interface. func (l *Limit) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { count, err := l.fetchCount(bindVars) diff --git a/go/vt/vtgate/engine/lock.go b/go/vt/vtgate/engine/lock.go index 794dbc7084e..2998de49f48 100644 --- a/go/vt/vtgate/engine/lock.go +++ b/go/vt/vtgate/engine/lock.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -58,6 +59,19 @@ func (l *Lock) GetTableName() string { return "dual" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (l *Lock) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + rss, _, err := vcursor.ResolveDestinations(l.Keyspace.Name, nil, []key.Destination{l.TargetDestination}) + if err != nil { + return err + } + if len(rss) != 1 { + return vterrors.Errorf(vtrpc.Code_FAILED_PRECONDITION, "lock query can be routed to single shard only: %v", rss) + } + each(rss[0]) + return nil +} + // Execute is part of the Primitive interface func (l *Lock) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { rss, _, err := vcursor.ResolveDestinations(l.Keyspace.Name, nil, []key.Destination{l.TargetDestination}) diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index 1b8c60635af..081eef86f0e 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -24,6 +24,7 @@ import ( "sort" "strings" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -59,6 +60,11 @@ func (ms *MemorySort) GetTableName() string { return ms.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (ms *MemorySort) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return ms.Input.GetExecShards(vcursor, bindVars, each) +} + // SetTruncateColumnCount sets the truncate column count. func (ms *MemorySort) SetTruncateColumnCount(count int) { ms.TruncateColumnCount = count diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index dbee39ee8e5..f22621ca49d 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -18,11 +18,11 @@ package engine import ( "container/heap" + "context" "io" "vitess.io/vitess/go/mysql" - - "context" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/sqltypes" @@ -65,6 +65,18 @@ func (ms *MergeSort) GetKeyspaceName() string { return "" } // GetTableName satisfies Primitive. func (ms *MergeSort) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (ms *MergeSort) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + for _, merge := range ms.Primitives { + if primitive, ok := merge.(Primitive); ok { + if err := primitive.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + } + } + return nil +} + // Execute is not supported. func (ms *MergeSort) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] Execute is not reachable") diff --git a/go/vt/vtgate/engine/mstream.go b/go/vt/vtgate/engine/mstream.go index aee29ddfa3e..28600b9a434 100644 --- a/go/vt/vtgate/engine/mstream.go +++ b/go/vt/vtgate/engine/mstream.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -58,6 +59,18 @@ func (m *MStream) GetTableName() string { return m.TableName } +// GetExecShards lists all the shards that would be accessed by this primitive +func (m *MStream) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + rss, _, err := vcursor.ResolveDestinations(m.Keyspace.Name, nil, []key.Destination{m.TargetDestination}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil +} + // Execute implements the Primitive interface func (m *MStream) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] 'Execute' called for Stream") diff --git a/go/vt/vtgate/engine/online_ddl.go b/go/vt/vtgate/engine/online_ddl.go index 32ebe9f53eb..f61aa63f7be 100644 --- a/go/vt/vtgate/engine/online_ddl.go +++ b/go/vt/vtgate/engine/online_ddl.go @@ -26,6 +26,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -71,6 +72,17 @@ func (v *OnlineDDL) GetTableName() string { return v.DDL.GetTable().Name.String() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (v *OnlineDDL) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // In Vitess 12, IsSkipTopo() is always true for all strategies + rss, _, err := vcursor.ResolveDestinations(v.Keyspace.Name, nil, []key.Destination{v.TargetDestination}) + if err != nil { + return err + } + each(rss[0]) + return nil +} + // Execute implements the Primitive interface func (v *OnlineDDL) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (result *sqltypes.Result, err error) { result = &sqltypes.Result{ diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index c31de3de5fc..9fe9998f5bf 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -21,6 +21,7 @@ import ( "strconv" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" "google.golang.org/protobuf/proto" @@ -182,6 +183,11 @@ func (oa *OrderedAggregate) GetTableName() string { return oa.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (oa *OrderedAggregate) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return oa.Input.GetExecShards(vcursor, bindVars, each) +} + // SetTruncateColumnCount sets the truncate column count. func (oa *OrderedAggregate) SetTruncateColumnCount(count int) { oa.TruncateColumnCount = count diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index ab08f72951e..13a878d308e 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -186,6 +186,7 @@ type ( RouteType() string GetKeyspaceName() string GetTableName() string + GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(*srvtopo.ResolvedShard)) error Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 8306efed86a..63534acb09a 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -3,6 +3,7 @@ package engine import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -31,6 +32,11 @@ func (p *Projection) GetTableName() string { return p.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (p *Projection) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return p.Input.GetExecShards(vcursor, bindVars, each) +} + // Execute implements the Primitive interface func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { result, err := p.Input.Execute(vcursor, bindVars, wantfields) diff --git a/go/vt/vtgate/engine/pullout_subquery.go b/go/vt/vtgate/engine/pullout_subquery.go index 732d49b15b7..40f3d55fa78 100644 --- a/go/vt/vtgate/engine/pullout_subquery.go +++ b/go/vt/vtgate/engine/pullout_subquery.go @@ -20,6 +20,7 @@ import ( "fmt" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" querypb "vitess.io/vitess/go/vt/proto/query" @@ -61,6 +62,17 @@ func (ps *PulloutSubquery) GetTableName() string { return ps.Underlying.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (ps *PulloutSubquery) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + if err := ps.Subquery.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + if err := ps.Underlying.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + return nil +} + // Execute satisfies the Primitive interface. func (ps *PulloutSubquery) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { combinedVars, err := ps.execSubquery(vcursor, bindVars) diff --git a/go/vt/vtgate/engine/rename_fields.go b/go/vt/vtgate/engine/rename_fields.go index 9983c5036e1..b6bf09b7b30 100644 --- a/go/vt/vtgate/engine/rename_fields.go +++ b/go/vt/vtgate/engine/rename_fields.go @@ -20,6 +20,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" ) @@ -60,6 +61,11 @@ func (r *RenameFields) GetTableName() string { return r.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (r *RenameFields) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return r.Input.GetExecShards(vcursor, bindVars, each) +} + // Execute implements the primitive interface func (r *RenameFields) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { qr, err := r.Input.Execute(vcursor, bindVars, wantfields) diff --git a/go/vt/vtgate/engine/replace_variables.go b/go/vt/vtgate/engine/replace_variables.go index c0397e26403..84f6d417a1b 100644 --- a/go/vt/vtgate/engine/replace_variables.go +++ b/go/vt/vtgate/engine/replace_variables.go @@ -19,6 +19,7 @@ package engine import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" ) var _ Primitive = (*ReplaceVariables)(nil) @@ -49,6 +50,11 @@ func (r *ReplaceVariables) GetTableName() string { return r.Input.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (r *ReplaceVariables) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return r.Input.GetExecShards(vcursor, bindVars, each) +} + // Execute implements the Primitive interface func (r *ReplaceVariables) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { qr, err := r.Input.Execute(vcursor, bindVars, wantfields) diff --git a/go/vt/vtgate/engine/revert_migration.go b/go/vt/vtgate/engine/revert_migration.go index b3c8594e49d..83dc8de406f 100644 --- a/go/vt/vtgate/engine/revert_migration.go +++ b/go/vt/vtgate/engine/revert_migration.go @@ -26,6 +26,7 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/schema" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -69,6 +70,17 @@ func (v *RevertMigration) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (v *RevertMigration) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // In Vitess 12, IsSkipTopo() is always true for all strategies + rss, _, err := vcursor.ResolveDestinations(v.Keyspace.Name, nil, []key.Destination{v.TargetDestination}) + if err != nil { + return err + } + each(rss[0]) + return nil +} + // Execute implements the Primitive interface func (v *RevertMigration) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (result *sqltypes.Result, err error) { result = &sqltypes.Result{ diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 7b5fe30eb59..00c5bfcd7c2 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -235,6 +235,18 @@ func (route *Route) GetTableName() string { return route.TableName } +// GetExecShards lists all the shards that would be accessed by this primitive +func (route *Route) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + rss, _, err := route.resolveShards(vcursor, bindVars) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil +} + // SetTruncateColumnCount sets the truncate column count. func (route *Route) SetTruncateColumnCount(count int) { route.TruncateColumnCount = count @@ -253,29 +265,29 @@ func (route *Route) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVa return qr.Truncate(route.TruncateColumnCount), nil } -func (route *Route) execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - var rss []*srvtopo.ResolvedShard - var bvs []map[string]*querypb.BindVariable - var err error +func (route *Route) resolveShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, error) { switch route.Opcode { case SelectDBA: - rss, bvs, err = route.paramsSystemQuery(vcursor, bindVars) + return route.paramsSystemQuery(vcursor, bindVars) case SelectUnsharded, SelectNext, SelectReference: - rss, bvs, err = route.paramsAnyShard(vcursor, bindVars) + return route.paramsAnyShard(vcursor, bindVars) case SelectScatter: - rss, bvs, err = route.paramsAllShards(vcursor, bindVars) + return route.paramsAllShards(vcursor, bindVars) case SelectEqual, SelectEqualUnique: - rss, bvs, err = route.paramsSelectEqual(vcursor, bindVars) + return route.paramsSelectEqual(vcursor, bindVars) case SelectIN: - rss, bvs, err = route.paramsSelectIn(vcursor, bindVars) + return route.paramsSelectIn(vcursor, bindVars) case SelectMultiEqual: - rss, bvs, err = route.paramsSelectMultiEqual(vcursor, bindVars) + return route.paramsSelectMultiEqual(vcursor, bindVars) case SelectNone: - rss, bvs, err = nil, nil, nil + return nil, nil, nil default: - // Unreachable. - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported query route: %v", route) + return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported query route: %v", route) } +} + +func (route *Route) execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + rss, bvs, err := route.resolveShards(vcursor, bindVars) if err != nil { return nil, err } @@ -324,35 +336,14 @@ func filterOutNilErrors(errs []error) []error { // StreamExecute performs a streaming exec. func (route *Route) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - var rss []*srvtopo.ResolvedShard - var bvs []map[string]*querypb.BindVariable - var err error if route.QueryTimeout != 0 { cancel := vcursor.SetContextTimeout(time.Duration(route.QueryTimeout) * time.Millisecond) defer cancel() } - switch route.Opcode { - case SelectDBA: - rss, bvs, err = route.paramsSystemQuery(vcursor, bindVars) - case SelectUnsharded, SelectNext, SelectReference: - rss, bvs, err = route.paramsAnyShard(vcursor, bindVars) - case SelectScatter: - rss, bvs, err = route.paramsAllShards(vcursor, bindVars) - case SelectEqual, SelectEqualUnique: - rss, bvs, err = route.paramsSelectEqual(vcursor, bindVars) - case SelectIN: - rss, bvs, err = route.paramsSelectIn(vcursor, bindVars) - case SelectMultiEqual: - rss, bvs, err = route.paramsSelectMultiEqual(vcursor, bindVars) - case SelectNone: - rss, bvs, err = nil, nil, nil - default: - return fmt.Errorf("query %q cannot be used for streaming", route.Query) - } + rss, bvs, err := route.resolveShards(vcursor, bindVars) if err != nil { return err } - // No route. if len(rss) == 0 { if wantfields { diff --git a/go/vt/vtgate/engine/rows.go b/go/vt/vtgate/engine/rows.go index 98e7fd5e57f..ab97010a253 100644 --- a/go/vt/vtgate/engine/rows.go +++ b/go/vt/vtgate/engine/rows.go @@ -19,6 +19,7 @@ package engine import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" ) var _ Primitive = (*Rows)(nil) @@ -52,6 +53,12 @@ func (r *Rows) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (r *Rows) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // No shards are accessed + return nil +} + //Execute implements the Primitive interface func (r *Rows) Execute(VCursor, map[string]*querypb.BindVariable, bool) (*sqltypes.Result, error) { return &sqltypes.Result{ diff --git a/go/vt/vtgate/engine/send.go b/go/vt/vtgate/engine/send.go index f42411b2666..da04923477a 100644 --- a/go/vt/vtgate/engine/send.go +++ b/go/vt/vtgate/engine/send.go @@ -20,6 +20,7 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -81,21 +82,38 @@ func (s *Send) GetTableName() string { return "" } -// Execute implements Primitive interface -func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { +func (s *Send) resolveShards(vcursor VCursor) ([]*srvtopo.ResolvedShard, error) { rss, _, err := vcursor.ResolveDestinations(s.Keyspace.Name, nil, []key.Destination{s.TargetDestination}) if err != nil { return nil, err } - if !s.Keyspace.Sharded && len(rss) != 1 { return nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "Keyspace does not have exactly one shard: %v", rss) } - if s.SingleShardOnly && len(rss) != 1 { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unexpected error, DestinationKeyspaceID mapping to multiple shards: %s, got: %v", s.Query, s.TargetDestination) } + return rss, nil +} + +// GetExecShards lists all the shards that would be accessed by this primitive +func (s *Send) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + rss, err := s.resolveShards(vcursor) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil +} +// Execute implements Primitive interface +func (s *Send) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + rss, err := s.resolveShards(vcursor) + if err != nil { + return nil, err + } queries := make([]*querypb.BoundQuery, len(rss)) for i, rs := range rss { bv := bindVars @@ -133,19 +151,10 @@ func copyBindVars(in map[string]*querypb.BindVariable) map[string]*querypb.BindV // StreamExecute implements Primitive interface func (s *Send) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - rss, _, err := vcursor.ResolveDestinations(s.Keyspace.Name, nil, []key.Destination{s.TargetDestination}) + rss, err := s.resolveShards(vcursor) if err != nil { return err } - - if !s.Keyspace.Sharded && len(rss) != 1 { - return vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "Keyspace does not have exactly one shard: %v", rss) - } - - if s.SingleShardOnly && len(rss) != 1 { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unexpected error, DestinationKeyspaceID mapping to multiple shards: %s, got: %v", s.Query, s.TargetDestination) - } - multiBindVars := make([]map[string]*querypb.BindVariable, len(rss)) for i, rs := range rss { bv := bindVars diff --git a/go/vt/vtgate/engine/session_primitive.go b/go/vt/vtgate/engine/session_primitive.go index f32d2e5f3ae..2308eeafca2 100644 --- a/go/vt/vtgate/engine/session_primitive.go +++ b/go/vt/vtgate/engine/session_primitive.go @@ -20,6 +20,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" ) @@ -58,6 +59,13 @@ func (s *SessionPrimitive) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (s *SessionPrimitive) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // the SessionPrimitive only changes the state on the current Session, it + // does not reach out to any shards + return nil +} + // Execute implements the Primitive interface func (s *SessionPrimitive) Execute(vcursor VCursor, _ map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { return s.action(vcursor.Session()) diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index f15c506c503..038c2f6946a 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -53,6 +53,7 @@ type ( // SetOp is an interface that different type of set operations implements. SetOp interface { Execute(vcursor VCursor, env evalengine.ExpressionEnv) error + GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error VariableName() string } @@ -109,6 +110,18 @@ func (s *Set) GetTableName() string { return "" } +func (s *Set) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + if err := s.Input.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + for _, setOp := range s.Ops { + if err := setOp.GetAffectedShards(vcursor, each); err != nil { + return err + } + } + return nil +} + //Execute implements the Primitive interface method. func (s *Set) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) { input, err := s.Input.Execute(vcursor, bindVars, false) @@ -181,6 +194,10 @@ func (u *UserDefinedVariable) VariableName() string { return u.Name } +func (u *UserDefinedVariable) GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error { + return nil +} + //Execute implements the SetOp interface method. func (u *UserDefinedVariable) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { value, err := u.Expr.Evaluate(env) @@ -209,6 +226,10 @@ func (svi *SysVarIgnore) VariableName() string { return svi.Name } +func (svi *SysVarIgnore) GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error { + return nil +} + //Execute implements the SetOp interface method. func (svi *SysVarIgnore) Execute(VCursor, evalengine.ExpressionEnv) error { log.Infof("Ignored inapplicable SET %v = %v", svi.Name, svi.Expr) @@ -234,6 +255,18 @@ func (svci *SysVarCheckAndIgnore) VariableName() string { return svci.Name } +func (svci *SysVarCheckAndIgnore) GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error { + rss, _, err := vcursor.ResolveDestinations(svci.Keyspace.Name, nil, []key.Destination{svci.TargetDestination}) + if err != nil { + return err + } + if len(rss) != 1 { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Unexpected error, DestinationKeyspaceID mapping to multiple shards: %v", svci.TargetDestination) + } + each(rss[0]) + return nil +} + //Execute implements the SetOp interface method func (svci *SysVarCheckAndIgnore) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { rss, _, err := vcursor.ResolveDestinations(svci.Keyspace.Name, nil, []key.Destination{svci.TargetDestination}) @@ -278,6 +311,24 @@ func (svs *SysVarReservedConn) VariableName() string { return svs.Name } +func (svs *SysVarReservedConn) GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error { + if svs.TargetDestination != nil { + rss, _, err := vcursor.ResolveDestinations(svs.Keyspace.Name, nil, []key.Destination{svs.TargetDestination}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + } + + for _, rs := range vcursor.Session().ShardSession() { + each(rs) + } + return nil +} + //Execute implements the SetOp interface method func (svs *SysVarReservedConn) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { // For those running on advanced vitess settings. @@ -362,6 +413,10 @@ func (svss *SysVarSetAware) MarshalJSON() ([]byte, error) { }) } +func (svss *SysVarSetAware) GetAffectedShards(vcursor VCursor, each func(rs *srvtopo.ResolvedShard)) error { + return nil +} + //Execute implements the SetOp interface method func (svss *SysVarSetAware) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { var err error diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go index a2a5b80b308..66ab6d70476 100644 --- a/go/vt/vtgate/engine/singlerow.go +++ b/go/vt/vtgate/engine/singlerow.go @@ -19,6 +19,8 @@ package engine import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/query" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/srvtopo" ) var _ Primitive = (*SingleRow)(nil) @@ -44,6 +46,10 @@ func (s *SingleRow) GetTableName() string { return "" } +func (s *SingleRow) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + return nil +} + // Execute performs a non-streaming exec. func (s *SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { result := sqltypes.Result{ diff --git a/go/vt/vtgate/engine/sql_calc_found_rows.go b/go/vt/vtgate/engine/sql_calc_found_rows.go index 65c3f7574dc..d002b6f4a8d 100644 --- a/go/vt/vtgate/engine/sql_calc_found_rows.go +++ b/go/vt/vtgate/engine/sql_calc_found_rows.go @@ -20,6 +20,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine" ) @@ -47,6 +48,17 @@ func (s SQLCalcFoundRows) GetTableName() string { return s.LimitPrimitive.GetTableName() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (s SQLCalcFoundRows) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + if err := s.LimitPrimitive.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + if err := s.CountPrimitive.GetExecShards(vcursor, bindVars, each); err != nil { + return err + } + return nil +} + //Execute implements the Primitive interface func (s SQLCalcFoundRows) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { limitQr, err := s.LimitPrimitive.Execute(vcursor, bindVars, wantfields) diff --git a/go/vt/vtgate/engine/subquery.go b/go/vt/vtgate/engine/subquery.go new file mode 100644 index 00000000000..e69de29bb2d diff --git a/go/vt/vtgate/engine/update.go b/go/vt/vtgate/engine/update.go index a8b5ea33c74..2fb4dfed664 100644 --- a/go/vt/vtgate/engine/update.go +++ b/go/vt/vtgate/engine/update.go @@ -80,6 +80,59 @@ func (upd *Update) GetTableName() string { return "" } +func (upd *Update) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + switch upd.Opcode { + case Unsharded: + rss, _, err := vcursor.ResolveDestinations(upd.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return err + } + each(rss[0]) + return nil + case Equal: + key, err := upd.Values[0].ResolveValue(bindVars) + if err != nil { + return err + } + rs, _, err := resolveSingleShard(vcursor, upd.Vindex, upd.Keyspace, key) + if err != nil { + return err + } + each(rs) + return nil + case In: + rss, _, err := resolveMultiValueShards(vcursor, upd.Keyspace, upd.Query, bindVars, upd.Values[0], upd.Vindex) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + case Scatter: + rss, _, err := vcursor.ResolveDestinations(upd.Keyspace.Name, nil, []key.Destination{key.DestinationAllShards{}}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + case ByDestination: + rss, _, err := vcursor.ResolveDestinations(upd.Keyspace.Name, nil, []key.Destination{upd.TargetDestination}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil + default: + // Unreachable. + return fmt.Errorf("unsupported opcode: %v", upd) + } +} + // Execute performs a non-streaming exec. func (upd *Update) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { if upd.QueryTimeout != 0 { diff --git a/go/vt/vtgate/engine/update_target.go b/go/vt/vtgate/engine/update_target.go index ad5d04d4e5e..e4e6cf87f81 100644 --- a/go/vt/vtgate/engine/update_target.go +++ b/go/vt/vtgate/engine/update_target.go @@ -18,6 +18,7 @@ package engine import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/sqltypes" @@ -58,6 +59,11 @@ func (updTarget *UpdateTarget) GetTableName() string { return "" } +func (updTarget *UpdateTarget) GetExecShards(vcursor VCursor, bindVars map[string]*query.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // This is a memory-only operation + return nil +} + // Execute implements the Primitive interface func (updTarget *UpdateTarget) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { err := vcursor.Session().SetTarget(updTarget.Target) diff --git a/go/vt/vtgate/engine/vindex_func.go b/go/vt/vtgate/engine/vindex_func.go index 0bdefcc98d7..6096f061589 100644 --- a/go/vt/vtgate/engine/vindex_func.go +++ b/go/vt/vtgate/engine/vindex_func.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -85,6 +86,12 @@ func (vf *VindexFunc) GetTableName() string { return "" } +// GetExecShards lists all the shards that would be accessed by this primitive +func (vf *VindexFunc) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // This Vindex query is resolved without accessing any shards + return nil +} + // Execute performs a non-streaming exec. func (vf *VindexFunc) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return vf.mapVindex(vcursor, bindVars) diff --git a/go/vt/vtgate/engine/vschema_ddl.go b/go/vt/vtgate/engine/vschema_ddl.go index 2cf2591a399..b8ae1bbed00 100644 --- a/go/vt/vtgate/engine/vschema_ddl.go +++ b/go/vt/vtgate/engine/vschema_ddl.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -63,6 +64,12 @@ func (v *AlterVSchema) GetTableName() string { return v.AlterVschemaDDL.Table.Name.String() } +// GetExecShards lists all the shards that would be accessed by this primitive +func (v *AlterVSchema) GetExecShards(vcursor VCursor, bindVars map[string]*query.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + // AlterVSchema is a topo operation so it doesn't reach out to any shards + return nil +} + //Execute implements the Primitive interface func (v *AlterVSchema) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { err := vcursor.ExecuteVSchema(v.Keyspace.Name, v.AlterVschemaDDL) diff --git a/go/vt/vtgate/engine/vstream.go b/go/vt/vtgate/engine/vstream.go index 66386ba0b37..0e6cdc7a77f 100644 --- a/go/vt/vtgate/engine/vstream.go +++ b/go/vt/vtgate/engine/vstream.go @@ -21,6 +21,7 @@ import ( "io" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/sqltypes" @@ -60,6 +61,18 @@ func (v *VStream) GetTableName() string { return v.TableName } +// GetExecShards lists all the shards that would be accessed by this primitive +func (v *VStream) GetExecShards(vcursor VCursor, bindVars map[string]*querypb.BindVariable, each func(rs *srvtopo.ResolvedShard)) error { + rss, _, err := vcursor.ResolveDestinations(v.Keyspace.Name, nil, []key.Destination{v.TargetDestination}) + if err != nil { + return err + } + for _, rs := range rss { + each(rs) + } + return nil +} + // Execute implements the Primitive interface func (v *VStream) Execute(_ VCursor, _ map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] 'Execute' called for VStream")