Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ require (
golang.org/x/lint v0.0.0-20190409202823-959b441ac422
golang.org/x/net v0.0.0-20200202094626-16171245cfb2
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
golang.org/x/text v0.3.2
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
golang.org/x/tools v0.0.0-20191219041853-979b82bfef62
Expand Down
38 changes: 37 additions & 1 deletion go/test/endtoend/vtgate/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"fmt"
"testing"

"vitess.io/vitess/go/test/utils"

"github.com/google/go-cmp/cmp"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -558,6 +560,40 @@ func TestCastConvert(t *testing.T) {
assertMatches(t, conn, `SELECT CAST("test" AS CHAR(60))`, `[[VARCHAR("test")]]`)
}

func TestUnionAll(t *testing.T) {
conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
defer conn.Close()

exec(t, conn, "delete from t1")
exec(t, conn, "delete from t2")

exec(t, conn, "insert into t1(id1, id2) values(1, 1), (2, 2)")
exec(t, conn, "insert into t2(id3, id4) values(3, 3), (4, 4)")

// union all between two selectuniqueequal
assertMatches(t, conn, "select id1 from t1 where id1 = 1 union all select id1 from t1 where id1 = 4", "[[INT64(1)]]")

// union all between two different tables
assertMatches(t, conn, "(select id1,id2 from t1 order by id1) union all (select id3,id4 from t2 order by id3)",
"[[INT64(1) INT64(1)] [INT64(2) INT64(2)] [INT64(3) INT64(3)] [INT64(4) INT64(4)]]")

// union all between two different tables
assertMatches(t, conn, "select tbl2.id1 FROM ((select id1 from t1 order by id1 limit 5) union all (select id1 from t1 order by id1 desc limit 5)) as tbl1 INNER JOIN t1 as tbl2 ON tbl1.id1 = tbl2.id1",
"[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]")

exec(t, conn, "insert into t1(id1, id2) values(3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8)")

// union all between two selectuniquein tables
qr := exec(t, conn, "select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8) union all select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8)")
expected := utils.SortString("[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)] [INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)]]")
assert.Equal(t, expected, utils.SortString(fmt.Sprintf("%v", qr.Rows)))

// clean up
exec(t, conn, "delete from t1")
exec(t, conn, "delete from t2")
}

func TestUnion(t *testing.T) {
conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
Expand All @@ -569,7 +605,7 @@ func TestUnion(t *testing.T) {
assertMatches(t, conn, `SELECT 1,'a' UNION ALL SELECT 1,'a' UNION ALL SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a') UNION ALL (SELECT 1,'a') UNION ALL (SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a' order by 1) union SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`)
assertMatches(t, conn, `(SELECT 1,'a' order by 1) union (SELECT 1,'a' ORDER BY 1)`, `[[INT64(1) VARCHAR("a")]]`)
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
Expand Down
13 changes: 13 additions & 0 deletions go/test/utils/sort.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package utils

import (
"sort"
"strings"
)

//SortString sorts the string.
func SortString(w string) string {
s := strings.Split(w, "")
sort.Strings(s)
return strings.Join(s, "")
}
1 change: 1 addition & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type (
iInsertRows()
AddOrder(*Order)
SetLimit(*Limit)
SetLock(lock string)
SQLNode
}

Expand Down
19 changes: 17 additions & 2 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,11 @@ func (node *Select) SetLimit(limit *Limit) {
node.Limit = limit
}

// SetLock sets the lock clause
func (node *Select) SetLock(lock string) {
node.Lock = lock
}

// AddWhere adds the boolean expression to the
// WHERE clause as an AND condition.
func (node *Select) AddWhere(expr Expr) {
Expand Down Expand Up @@ -751,12 +756,17 @@ func (node *Select) AddHaving(expr Expr) {

// AddOrder adds an order by element
func (node *ParenSelect) AddOrder(order *Order) {
panic("unreachable")
node.Select.AddOrder(order)
}

// SetLimit sets the limit clause
func (node *ParenSelect) SetLimit(limit *Limit) {
panic("unreachable")
node.Select.SetLimit(limit)
}

// SetLock sets the lock clause
func (node *ParenSelect) SetLock(lock string) {
node.Select.SetLock(lock)
}

// AddOrder adds an order by element
Expand All @@ -769,6 +779,11 @@ func (node *Union) SetLimit(limit *Limit) {
node.Limit = limit
}

// SetLock sets the lock clause
func (node *Union) SetLock(lock string) {
node.Lock = lock
}

//Unionize returns a UNION, either creating one or adding SELECT to an existing one
func Unionize(lhs, rhs SelectStatement, typ string, by OrderBy, limit *Limit, lock string) *Union {
union, isUnion := lhs.(*Union)
Expand Down
209 changes: 209 additions & 0 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
Copyright 2020 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package engine

import (
"sort"
"strings"
"sync"

"vitess.io/vitess/go/mysql"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// Concatenate Primitive is used to concatenate results from multiple sources.
var _ Primitive = (*Concatenate)(nil)

//Concatenate specified the parameter for concatenate primitive
type Concatenate struct {
Sources []Primitive
}

//RouteType returns a description of the query routing type used by the primitive
func (c *Concatenate) RouteType() string {
return "Concatenate"
}

// GetKeyspaceName specifies the Keyspace that this primitive routes to
func (c *Concatenate) GetKeyspaceName() string {
ksMap := map[string]interface{}{}
for _, source := range c.Sources {
ksMap[source.GetKeyspaceName()] = nil
}
var ksArr []string
for ks := range ksMap {
ksArr = append(ksArr, ks)
}
sort.Strings(ksArr)
return strings.Join(ksArr, "_")
}

// GetTableName specifies the table that this primitive routes to.
func (c *Concatenate) GetTableName() string {
var tabArr []string
for _, source := range c.Sources {
tabArr = append(tabArr, source.GetTableName())
}
return strings.Join(tabArr, "_")
}

// Execute performs a non-streaming exec.
func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
result := &sqltypes.Result{}
var wg sync.WaitGroup
qrs := make([]*sqltypes.Result, len(c.Sources))
errs := make([]error, len(c.Sources))
for i, source := range c.Sources {
wg.Add(1)
go func(i int, source Primitive) {
defer wg.Done()
qrs[i], errs[i] = source.Execute(vcursor, bindVars, wantfields)
}(i, source)
}
wg.Wait()
for i := 0; i < len(c.Sources); i++ {
if errs[i] != nil {
return nil, vterrors.Wrap(errs[i], "Concatenate.Execute")
}
qr := qrs[i]
if result.Fields == nil {
result.Fields = qr.Fields
}
err := compareFields(result.Fields, qr.Fields)
if err != nil {
return nil, err
}
if len(qr.Rows) > 0 {
result.Rows = append(result.Rows, qr.Rows...)
if len(result.Rows[0]) != len(qr.Rows[0]) {
return nil, mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns")
}
result.RowsAffected += qr.RowsAffected
}
}
return result, nil
}

// StreamExecute performs a streaming exec.
func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var seenFields []*querypb.Field
var fieldset sync.WaitGroup
fieldsSent := false

g := vcursor.ErrorGroupCancellableContext()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general rule, we should not change the caller's context. Although things will work as expected in this case, it's better to still follow the rule. This means that you have to create a local cancel-able one and use that instead.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need the engine primitives to use this new cancelable context. Contexts are like onions - there are layers to them. We are merely adding a cancelable layer to the context the caller provided us with. if you want, we could peel that layer before leaving this method, to restore the original context.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh. I didn't realize that you were passing vcursor down to the callees. So, it's ok to leave it as is for now.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context created here needs to be sent down to calls from here one. So that once the context is cancelled. The lower layers can also stop processing.
The issue here is the lower layer can receive this new context via vcursor as of now. So, the current way was to replace it.
Once we start passing context separately this would not be an issue.

@systay The current suggestion does not hold. Because if we replace the vcursor context and there is parent branch which has vcursor usage. It will already be updated and using the new context. So peeling off before returning may not guarantee things are in order.

fieldset.Add(1)
var mu sync.Mutex
for i, source := range c.Sources {
i, source := i, source
g.Go(func() error {
err := source.StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error {
// if we have fields to compare, make sure all the fields are all the same
if i == 0 && !fieldsSent {
defer fieldset.Done()
seenFields = resultChunk.Fields
fieldsSent = true
// No other call can happen before this call.
return callback(resultChunk)
}
fieldset.Wait()
if resultChunk.Fields != nil {
err := compareFields(seenFields, resultChunk.Fields)
if err != nil {
return err
}
}
// This to ensure only one send happens back to the client.
mu.Lock()
defer mu.Unlock()
select {
case <-vcursor.Context().Done():
return nil
default:
return callback(resultChunk)
}
})
// This is to ensure other streams complete if the first stream failed to unlock the wait.
if i == 0 && !fieldsSent {
fieldset.Done()
}
return err
})

}
if err := g.Wait(); err != nil {
return err
}
return nil
}

// GetFields fetches the field info.
func (c *Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
firstQr, err := c.Sources[0].GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}
for i, source := range c.Sources {
if i == 0 {
continue
}
qr, err := source.GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}
err = compareFields(firstQr.Fields, qr.Fields)
if err != nil {
return nil, err
}
}
return firstQr, nil
}

//NeedsTransaction returns whether a transaction is needed for this primitive
func (c *Concatenate) NeedsTransaction() bool {
for _, source := range c.Sources {
if source.NeedsTransaction() {
return true
}
}
return false
}

// Inputs returns the input primitives for this
func (c *Concatenate) Inputs() []Primitive {
return c.Sources
}

func (c *Concatenate) description() PrimitiveDescription {
return PrimitiveDescription{OperatorType: c.RouteType()}
}

func compareFields(fields1 []*querypb.Field, fields2 []*querypb.Field) error {
if len(fields1) != len(fields2) {
return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns")
}
for i, field2 := range fields2 {
field1 := fields1[i]
if field1.Type != field2.Type {
return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "column field type does not match for name: (%v, %v) types: (%v, %v)", field1.Name, field2.Name, field1.Type, field2.Type)
}
}
return nil
}
Loading