Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/youtube/vitess/go/sqltypes"

querypb "github.com/youtube/vitess/go/vt/proto/query"
"github.com/youtube/vitess/go/vt/vtgate/queryinfo"
)

// Join specifies the parameters for a join primitive.
Expand All @@ -33,8 +34,8 @@ type Join struct {
}

// Execute performs a non-streaming exec.
func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) {
lresult, err := jn.Left.Execute(vcursor, joinvars, wantfields)
func (jn *Join) Execute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error) {
lresult, err := jn.Left.Execute(vcursor, queryConstruct, joinvars, wantfields)
if err != nil {
return nil, err
}
Expand All @@ -43,7 +44,7 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi
for k := range jn.Vars {
joinvars[k] = nil
}
rresult, err := jn.Right.GetFields(vcursor, joinvars)
rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars)
if err != nil {
return nil, err
}
Expand All @@ -54,7 +55,7 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi
for k, col := range jn.Vars {
joinvars[k] = lrow[col]
}
rresult, err := jn.Right.Execute(vcursor, joinvars, wantfields)
rresult, err := jn.Right.Execute(vcursor, queryConstruct, joinvars, wantfields)
if err != nil {
return nil, err
}
Expand All @@ -76,14 +77,14 @@ func (jn *Join) Execute(vcursor VCursor, joinvars map[string]interface{}, wantfi
}

// StreamExecute performs a streaming exec.
func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error {
err := jn.Left.StreamExecute(vcursor, joinvars, wantfields, func(lresult *sqltypes.Result) error {
func (jn *Join) StreamExecute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool, sendReply func(*sqltypes.Result) error) error {
err := jn.Left.StreamExecute(vcursor, queryConstruct, joinvars, wantfields, func(lresult *sqltypes.Result) error {
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinvars[k] = lrow[col]
}
rowSent := false
err := jn.Right.StreamExecute(vcursor, joinvars, wantfields, func(rresult *sqltypes.Result) error {
err := jn.Right.StreamExecute(vcursor, queryConstruct, joinvars, wantfields, func(rresult *sqltypes.Result) error {
result := &sqltypes.Result{}
if wantfields {
wantfields = false
Expand Down Expand Up @@ -120,7 +121,7 @@ func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{},
joinvars[k] = nil
}
result := &sqltypes.Result{}
rresult, err := jn.Right.GetFields(vcursor, joinvars)
rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars)
if err != nil {
return err
}
Expand All @@ -133,16 +134,16 @@ func (jn *Join) StreamExecute(vcursor VCursor, joinvars map[string]interface{},
}

// GetFields fetches the field info.
func (jn *Join) GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sqltypes.Result, error) {
lresult, err := jn.Left.GetFields(vcursor, joinvars)
func (jn *Join) GetFields(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}) (*sqltypes.Result, error) {
lresult, err := jn.Left.GetFields(vcursor, queryConstruct, joinvars)
if err != nil {
return nil, err
}
result := &sqltypes.Result{}
for k := range jn.Vars {
joinvars[k] = nil
}
rresult, err := jn.Right.GetFields(vcursor, joinvars)
rresult, err := jn.Right.GetFields(vcursor, queryConstruct, joinvars)
if err != nil {
return nil, err
}
Expand Down
24 changes: 17 additions & 7 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

package engine

import "github.com/youtube/vitess/go/sqltypes"
import (
"github.com/youtube/vitess/go/sqltypes"
topodatapb "github.com/youtube/vitess/go/vt/proto/topodata"
"github.com/youtube/vitess/go/vt/tabletserver/querytypes"
"github.com/youtube/vitess/go/vt/vtgate/queryinfo"
)

// SeqVarName is a reserved bind var name for sequence values.
const SeqVarName = "__seq"
Expand All @@ -17,9 +22,14 @@ const ListVarName = "__vals"
// VCursor defines the interface the engine will use
// to execute routes.
type VCursor interface {
ExecuteRoute(route *Route, joinvars map[string]interface{}) (*sqltypes.Result, error)
StreamExecuteRoute(route *Route, joinvars map[string]interface{}, sendReply func(*sqltypes.Result) error) error
GetRouteFields(route *Route, joinvars map[string]interface{}) (*sqltypes.Result, error)
ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery, notInTransaction bool) (*sqltypes.Result, error)
StreamExecuteMulti(query string, keyspace string, shardVars map[string]map[string]interface{}, sendReply func(reply *sqltypes.Result) error) error
GetAnyShard(keyspace string) (ks, shard string, err error)
ScatterConnExecute(query string, bindVars map[string]interface{}, keyspace string, shards []string, notInTransaction bool) (*sqltypes.Result, error)
GetKeyspaceShards(keyspace string) (string, *topodatapb.SrvKeyspace, []*topodatapb.ShardReference, error)
GetShardForKeyspaceID(allShards []*topodatapb.ShardReference, keyspaceID []byte) (string, error)
ExecuteShard(keyspace string, shardQueries map[string]querytypes.BoundQuery) (*sqltypes.Result, error)
Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error)
}

// Plan represents the execution strategy for a given query.
Expand All @@ -45,7 +55,7 @@ func (pln *Plan) Size() int {
// Primitive is the interface that needs to be satisfied by
// all primitives of a plan.
type Primitive interface {
Execute(vcursor VCursor, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error)
StreamExecute(vcursor VCursor, joinvars map[string]interface{}, wantields bool, sendReply func(*sqltypes.Result) error) error
GetFields(vcursor VCursor, joinvars map[string]interface{}) (*sqltypes.Result, error)
Execute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantfields bool) (*sqltypes.Result, error)
StreamExecute(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}, wantields bool, sendReply func(*sqltypes.Result) error) error
GetFields(vcursor VCursor, queryConstruct *queryinfo.QueryConstruct, joinvars map[string]interface{}) (*sqltypes.Result, error)
}
Loading