From a75e95abf494f92607f172bb3f8e11e40cd4ed70 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 22 Jun 2020 14:36:04 +0530 Subject: [PATCH 01/13] union-all: added concatenate primitive Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 89 +++++++++++++++++++ go/vt/vtgate/engine/concatenate_test.go | 57 ++++++++++++ go/vt/vtgate/planbuilder/testdata/onecase.txt | 5 +- 3 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 go/vt/vtgate/engine/concatenate.go create mode 100644 go/vt/vtgate/engine/concatenate_test.go diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go new file mode 100644 index 00000000000..8ef8cfcf320 --- /dev/null +++ b/go/vt/vtgate/engine/concatenate.go @@ -0,0 +1,89 @@ +package engine + +import ( + "sort" + "strings" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +var _ Primitive = (*Concatenate)(nil) + +type Concatenate struct { + Sources []Primitive +} + +func (c Concatenate) RouteType() string { + return "Concatenate" +} + +func (c Concatenate) GetKeyspaceName() string { + ksMap := map[string]interface{}{} + for _, source := range c.Sources { + ksMap[source.GetKeyspaceName()] = nil + } + var ksArr []string + for ks := range ksMap { + ksArr = append(ksArr, ks) + } + sort.Strings(ksArr) + return strings.Join(ksArr, "_") +} + +func (c Concatenate) GetTableName() string { + var tabArr []string + for _, source := range c.Sources { + tabArr = append(tabArr, source.GetTableName()) + } + return strings.Join(tabArr, "_") +} + +func (c Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + panic("implement me") +} + +func (c Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (c Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + firstQr, err := c.Sources[0].GetFields(vcursor, bindVars) + if err != nil { + return nil, err + } + for i, source := range c.Sources { + if i == 0 { + continue + } + qr, err := source.GetFields(vcursor, bindVars) + if err != nil { + return nil, err + } + for i, field := range qr.Fields { + if firstQr.Fields[i].Type != field.Type { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "column field type does not match for name: (%v, %v) types: (%v, %v)", firstQr.Fields[i].Name, field.Name, firstQr.Fields[i].Type, field.Type) + } + } + } + return firstQr, nil +} + +func (c Concatenate) NeedsTransaction() bool { + for _, source := range c.Sources { + if source.NeedsTransaction() { + return true + } + } + return false +} + +func (c Concatenate) Inputs() []Primitive { + return c.Sources +} + +func (c Concatenate) description() PrimitiveDescription { + return PrimitiveDescription{OperatorType: c.RouteType()} +} diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go new file mode 100644 index 00000000000..89a294ecb48 --- /dev/null +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -0,0 +1,57 @@ +package engine + +import ( + "testing" + + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" +) + +func TestConcatenate_Execute(t *testing.T) { + type testCase struct { + testName string + inputs []*sqltypes.Result + expectedResult *sqltypes.Result + expectedError string + } + + testCases := []*testCase{ + { + testName: "2 empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + }, + expectedResult: nil, + expectedError: "", + }, + { + testName: "3 empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + }, + expectedResult: nil, + expectedError: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + var fps []Primitive + for _, input := range tc.inputs { + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input}}) + } + concatenate := Concatenate{Sources: fps} + qr, err := concatenate.Execute(&noopVCursor{}, nil, true) + if tc.expectedError == "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, qr) + } else { + require.EqualError(t, err, tc.expectedError) + } + }) + } + +} diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.txt b/go/vt/vtgate/planbuilder/testdata/onecase.txt index e819513f354..f821c90b4e5 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.txt +++ b/go/vt/vtgate/planbuilder/testdata/onecase.txt @@ -1 +1,4 @@ -# Add your test case here for debugging and run go test -run=One. +# union all using sharded +"select id from user union all select id from user" +{ +} From 1ffea268c6d6ac86818b5ea784d17eba98933324 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 22 Jun 2020 17:41:08 +0530 Subject: [PATCH 02/13] union-all: concatenate engine implementation Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 76 ++++++++++++++++++++----- go/vt/vtgate/engine/concatenate_test.go | 66 +++++++++++++++------ 2 files changed, 112 insertions(+), 30 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 8ef8cfcf320..32970517f13 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -4,23 +4,29 @@ import ( "sort" "strings" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) +// Concatenate Primitive is used to concatenate results from multiple sources. var _ Primitive = (*Concatenate)(nil) +//Concatenate specified the parameter for concatenate primitive type Concatenate struct { Sources []Primitive } -func (c Concatenate) RouteType() string { +//RouteType returns a description of the query routing type used by the primitive +func (c *Concatenate) RouteType() string { return "Concatenate" } -func (c Concatenate) GetKeyspaceName() string { +// GetKeyspaceName specifies the Keyspace that this primitive routes to +func (c *Concatenate) GetKeyspaceName() string { ksMap := map[string]interface{}{} for _, source := range c.Sources { ksMap[source.GetKeyspaceName()] = nil @@ -33,7 +39,8 @@ func (c Concatenate) GetKeyspaceName() string { return strings.Join(ksArr, "_") } -func (c Concatenate) GetTableName() string { +// GetTableName specifies the table that this primitive routes to. +func (c *Concatenate) GetTableName() string { var tabArr []string for _, source := range c.Sources { tabArr = append(tabArr, source.GetTableName()) @@ -41,15 +48,42 @@ func (c Concatenate) GetTableName() string { return strings.Join(tabArr, "_") } -func (c Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - panic("implement me") +// Execute performs a non-streaming exec. +func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + result := &sqltypes.Result{} + for _, source := range c.Sources { + qr, err := source.Execute(vcursor, bindVars, wantfields) + if err != nil { + return nil, vterrors.Wrap(err, "Concatenate.Execute: ") + } + if wantfields { + wantfields = false + result.Fields = qr.Fields + } + if result.Fields != nil { + err = compareFields(result.Fields, qr.Fields) + if err != nil { + return nil, err + } + } + if len(qr.Rows) > 0 { + result.Rows = append(result.Rows, qr.Rows...) + if len(result.Rows[0]) != len(qr.Rows[0]) { + return nil, mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns") + } + result.RowsAffected += qr.RowsAffected + } + } + return result, nil } -func (c Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { +// StreamExecute performs a streaming exec. +func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { panic("implement me") } -func (c Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { +// GetFields fetches the field info. +func (c *Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { firstQr, err := c.Sources[0].GetFields(vcursor, bindVars) if err != nil { return nil, err @@ -62,16 +96,16 @@ func (c Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bin if err != nil { return nil, err } - for i, field := range qr.Fields { - if firstQr.Fields[i].Type != field.Type { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "column field type does not match for name: (%v, %v) types: (%v, %v)", firstQr.Fields[i].Name, field.Name, firstQr.Fields[i].Type, field.Type) - } + err = compareFields(firstQr.Fields, qr.Fields) + if err != nil { + return nil, err } } return firstQr, nil } -func (c Concatenate) NeedsTransaction() bool { +//NeedsTransaction returns whether a transaction is needed for this primitive +func (c *Concatenate) NeedsTransaction() bool { for _, source := range c.Sources { if source.NeedsTransaction() { return true @@ -80,10 +114,24 @@ func (c Concatenate) NeedsTransaction() bool { return false } -func (c Concatenate) Inputs() []Primitive { +// Inputs returns the input primitives for this +func (c *Concatenate) Inputs() []Primitive { return c.Sources } -func (c Concatenate) description() PrimitiveDescription { +func (c *Concatenate) description() PrimitiveDescription { return PrimitiveDescription{OperatorType: c.RouteType()} } + +func compareFields(fields1 []*querypb.Field, fields2 []*querypb.Field) error { + if len(fields1) != len(fields2) { + return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns") + } + for i, field2 := range fields2 { + field1 := fields1[i] + if field1.Type != field2.Type { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "column field type does not match for name: (%v, %v) types: (%v, %v)", field1.Name, field2.Name, field1.Type, field2.Type) + } + } + return nil +} diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index 89a294ecb48..fa9085d125d 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -9,31 +9,54 @@ import ( func TestConcatenate_Execute(t *testing.T) { type testCase struct { - testName string - inputs []*sqltypes.Result - expectedResult *sqltypes.Result - expectedError string + testName string + inputs []*sqltypes.Result + expectedResult *sqltypes.Result + expectedError string + skipTestWithFalseWantsFields bool } testCases := []*testCase{ { - testName: "2 empty result", + testName: "empty results", inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), }, - expectedResult: nil, - expectedError: "", + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), }, { - testName: "3 empty result", + testName: "2 non empty result", inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("", ""), ""), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, - expectedResult: nil, - expectedError: "", + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2", "1|a1|b1", "2|a2|b2"), + }, + { + testName: "mismatch field type", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + expectedError: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + skipTestWithFalseWantsFields: true, + }, + { + testName: "input source has different column count", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), + }, + expectedError: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + }, + { + testName: "1 empty result and 1 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, } @@ -41,7 +64,7 @@ func TestConcatenate_Execute(t *testing.T) { t.Run(tc.testName, func(t *testing.T) { var fps []Primitive for _, input := range tc.inputs { - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input}}) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) } concatenate := Concatenate{Sources: fps} qr, err := concatenate.Execute(&noopVCursor{}, nil, true) @@ -51,6 +74,17 @@ func TestConcatenate_Execute(t *testing.T) { } else { require.EqualError(t, err, tc.expectedError) } + if tc.skipTestWithFalseWantsFields { + return + } + qr, err = concatenate.Execute(&noopVCursor{}, nil, false) + if tc.expectedError == "" { + require.NoError(t, err) + tc.expectedResult.Fields = nil + require.Equal(t, tc.expectedResult, qr) + } else { + require.EqualError(t, err, tc.expectedError) + } }) } From de3d16a438c6a023844edf47eba7d92f3d7f8071 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 22 Jun 2020 17:47:31 +0200 Subject: [PATCH 03/13] union-all: implemented StreamExecute Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/concatenate.go | 62 +++++++++-- go/vt/vtgate/engine/concatenate_test.go | 137 +++++++++++++----------- 2 files changed, 125 insertions(+), 74 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 32970517f13..45331c3cbe6 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -1,3 +1,19 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package engine import ( @@ -54,17 +70,14 @@ func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind for _, source := range c.Sources { qr, err := source.Execute(vcursor, bindVars, wantfields) if err != nil { - return nil, vterrors.Wrap(err, "Concatenate.Execute: ") + return nil, vterrors.Wrap(err, "Concatenate.Execute") } - if wantfields { - wantfields = false + if result.Fields == nil { result.Fields = qr.Fields } - if result.Fields != nil { - err = compareFields(result.Fields, qr.Fields) - if err != nil { - return nil, err - } + err = compareFields(result.Fields, qr.Fields) + if err != nil { + return nil, err } if len(qr.Rows) > 0 { result.Rows = append(result.Rows, qr.Rows...) @@ -78,8 +91,37 @@ func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind } // StreamExecute performs a streaming exec. -func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { - panic("implement me") +func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + var seenFields []*querypb.Field + columnCount := 0 + for _, source := range c.Sources { + err := source.StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + // if we have fields to compare, make sure all the fields are all the same + if seenFields == nil { + seenFields = resultChunk.Fields + } else if resultChunk.Fields != nil { + err := compareFields(seenFields, resultChunk.Fields) + if err != nil { + return err + } + } + if len(resultChunk.Rows) > 0 { + if columnCount == 0 { + columnCount = len(resultChunk.Rows[0]) + } else { + if len(resultChunk.Rows[0]) != columnCount { + return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") + } + } + } + + return callback(resultChunk) + }) + if err != nil { + return err + } + } + return nil } // GetFields fetches the field info. diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index fa9085d125d..796eb3da73f 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package engine import ( @@ -9,83 +25,76 @@ import ( func TestConcatenate_Execute(t *testing.T) { type testCase struct { - testName string - inputs []*sqltypes.Result - expectedResult *sqltypes.Result - expectedError string - skipTestWithFalseWantsFields bool + testName string + inputs []*sqltypes.Result + expectedResult *sqltypes.Result + expectedError string } - testCases := []*testCase{ - { - testName: "empty results", - inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), - }, - expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + testCases := []*testCase{{ + testName: "empty results", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), }, - { - testName: "2 non empty result", - inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2"), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), - }, - expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2", "1|a1|b1", "2|a2|b2"), + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + }, { + testName: "2 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, - { - testName: "mismatch field type", - inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), - }, - expectedError: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", - skipTestWithFalseWantsFields: true, + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2", "1|a1|b1", "2|a2|b2"), + }, { + testName: "mismatch field type", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, - { - testName: "input source has different column count", - inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), - }, - expectedError: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + expectedError: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + }, { + testName: "input source has different column count", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), }, - { - testName: "1 empty result and 1 non empty result", - inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary")), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), - }, - expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + expectedError: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + }, { + testName: "1 empty result and 1 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, - } + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }} for _, tc := range testCases { t.Run(tc.testName, func(t *testing.T) { var fps []Primitive for _, input := range tc.inputs { - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) - } - concatenate := Concatenate{Sources: fps} - qr, err := concatenate.Execute(&noopVCursor{}, nil, true) - if tc.expectedError == "" { - require.NoError(t, err) - require.Equal(t, tc.expectedResult, qr) - } else { - require.EqualError(t, err, tc.expectedError) - } - if tc.skipTestWithFalseWantsFields { - return - } - qr, err = concatenate.Execute(&noopVCursor{}, nil, false) - if tc.expectedError == "" { - require.NoError(t, err) - tc.expectedResult.Fields = nil - require.Equal(t, tc.expectedResult, qr) - } else { - require.EqualError(t, err, tc.expectedError) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input, input, input}}) } + concatenate := &Concatenate{Sources: fps} + + t.Run("Execute wantfields true", func(t *testing.T) { + qr, err := concatenate.Execute(&noopVCursor{}, nil, true) + if tc.expectedError == "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, qr) + } else { + require.EqualError(t, err, tc.expectedError) + } + }) + + t.Run("StreamExecute wantfields true", func(t *testing.T) { + qr, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) + if tc.expectedError == "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, qr) + } else { + require.EqualError(t, err, tc.expectedError) + } + }) }) } - } From 5d56e28d2efa16ac311ce31e4f0e510b80133b83 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 23 Jun 2020 16:38:06 +0200 Subject: [PATCH 04/13] union-all: added endtoendtest Signed-off-by: Andres Taylor --- go/test/endtoend/vtgate/misc_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index 3ee049f0e6d..684d3a5de1b 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -558,6 +558,22 @@ func TestCastConvert(t *testing.T) { assertMatches(t, conn, `SELECT CAST("test" AS CHAR(60))`, `[[VARCHAR("test")]]`) } +func TestUnionAll(t *testing.T) { + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + defer conn.Close() + + exec(t, conn, "insert into t1(id1, id2) values(1, 1), (2, 2)") + exec(t, conn, "insert into t2(id3, id4) values(3, 3), (4, 4)") + + // union all between two selectuniqueequal + assertMatches(t, conn, "select id1 from t1 where id1 = 1 union all select id1 from t1 where id1 = 4", "[[INT64(1)]]") + + // union all between two different tables + assertMatches(t, conn, "(select id1,id2 from t1 order by id1) union all (select id3,id4 from t2 order by id3)", + "[[INT64(1) INT64(1)] [INT64(2) INT64(2)] [INT64(3) INT64(3)] [INT64(4) INT64(4)]]") +} + func TestUnion(t *testing.T) { conn, err := mysql.Connect(context.Background(), &vtParams) require.NoError(t, err) From a9b9898a1f039c323ad44a95ab830c5b178193d7 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 23 Jun 2020 15:45:42 +0200 Subject: [PATCH 05/13] union-all: added planning for UNION ALL Signed-off-by: Andres Taylor --- go/test/endtoend/vtgate/misc_test.go | 5 +- go/vt/sqlparser/ast.go | 1 + go/vt/sqlparser/ast_funcs.go | 19 +- go/vt/vtgate/planbuilder/builder.go | 3 + go/vt/vtgate/planbuilder/concatenate.go | 124 +++++++++ go/vt/vtgate/planbuilder/join.go | 10 + go/vt/vtgate/planbuilder/limit.go | 5 + go/vt/vtgate/planbuilder/memory_sort.go | 5 + go/vt/vtgate/planbuilder/merge_sort.go | 10 +- go/vt/vtgate/planbuilder/ordered_aggregate.go | 5 + go/vt/vtgate/planbuilder/plan_test.go | 7 +- go/vt/vtgate/planbuilder/pullout_subquery.go | 10 + go/vt/vtgate/planbuilder/route.go | 11 +- go/vt/vtgate/planbuilder/subquery.go | 5 + go/vt/vtgate/planbuilder/testdata/onecase.txt | 5 +- .../planbuilder/testdata/union_cases.txt | 243 ++++++++++++++++++ .../testdata/unsupported_cases.txt | 28 +- go/vt/vtgate/planbuilder/union.go | 82 +++++- go/vt/vtgate/planbuilder/vindex_func.go | 5 + 19 files changed, 542 insertions(+), 41 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/concatenate.go create mode 100644 go/vt/vtgate/planbuilder/testdata/union_cases.txt diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index 684d3a5de1b..d3f379f0d78 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -563,6 +563,9 @@ func TestUnionAll(t *testing.T) { require.NoError(t, err) defer conn.Close() + exec(t, conn, "delete from t1") + exec(t, conn, "delete from t2") + exec(t, conn, "insert into t1(id1, id2) values(1, 1), (2, 2)") exec(t, conn, "insert into t2(id3, id4) values(3, 3), (4, 4)") @@ -585,7 +588,7 @@ func TestUnion(t *testing.T) { assertMatches(t, conn, `SELECT 1,'a' UNION ALL SELECT 1,'a' UNION ALL SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`) assertMatches(t, conn, `(SELECT 1,'a') UNION ALL (SELECT 1,'a') UNION ALL (SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")] [INT64(1) VARCHAR("a")]]`) assertMatches(t, conn, `(SELECT 1,'a') ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`) - assertMatches(t, conn, `(SELECT 1,'a' order by 1) union SELECT 1,'a' ORDER BY 1`, `[[INT64(1) VARCHAR("a")]]`) + assertMatches(t, conn, `(SELECT 1,'a' order by 1) union (SELECT 1,'a' ORDER BY 1)`, `[[INT64(1) VARCHAR("a")]]`) } func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 06e8d6e706f..60a699d993a 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -50,6 +50,7 @@ type ( iInsertRows() AddOrder(*Order) SetLimit(*Limit) + SetLock(lock string) SQLNode } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 82844737e46..cd966f82426 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -717,6 +717,11 @@ func (node *Select) SetLimit(limit *Limit) { node.Limit = limit } +// SetLock sets the lock clause +func (node *Select) SetLock(lock string) { + node.Lock = lock +} + // AddWhere adds the boolean expression to the // WHERE clause as an AND condition. func (node *Select) AddWhere(expr Expr) { @@ -751,12 +756,17 @@ func (node *Select) AddHaving(expr Expr) { // AddOrder adds an order by element func (node *ParenSelect) AddOrder(order *Order) { - panic("unreachable") + node.Select.AddOrder(order) } // SetLimit sets the limit clause func (node *ParenSelect) SetLimit(limit *Limit) { - panic("unreachable") + node.Select.SetLimit(limit) +} + +// SetLock sets the lock clause +func (node *ParenSelect) SetLock(lock string) { + node.Select.SetLock(lock) } // AddOrder adds an order by element @@ -769,6 +779,11 @@ func (node *Union) SetLimit(limit *Limit) { node.Limit = limit } +// SetLock sets the lock clause +func (node *Union) SetLock(lock string) { + node.Lock = lock +} + //Unionize returns a UNION, either creating one or adding SELECT to an existing one func Unionize(lhs, rhs SelectStatement, typ string, by OrderBy, limit *Limit, lock string) *Union { union, isUnion := lhs.(*Union) diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 175f61846d1..ae348a23290 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -110,6 +110,9 @@ type builder interface { // specified column. SupplyWeightString(colNumber int) (weightcolNumber int, err error) + // PushLock pushes "FOR UPDATE", "LOCK IN SHARE MODE" down to all routes + PushLock(lock string) error + // Primitive returns the underlying primitive. // This function should only be called after Wireup is finished. Primitive() engine.Primitive diff --git a/go/vt/vtgate/planbuilder/concatenate.go b/go/vt/vtgate/planbuilder/concatenate.go new file mode 100644 index 00000000000..fe3e4d8eea9 --- /dev/null +++ b/go/vt/vtgate/planbuilder/concatenate.go @@ -0,0 +1,124 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" +) + +type concatenate struct { + lhs, rhs builder + order int +} + +var _ builder = (*concatenate)(nil) + +func (c *concatenate) Order() int { + return c.order +} + +func (c *concatenate) ResultColumns() []*resultColumn { + return c.lhs.ResultColumns() +} + +func (c *concatenate) Reorder(order int) { + c.lhs.Reorder(order) + c.rhs.Reorder(c.lhs.Order()) + c.order = c.rhs.Order() + 1 +} + +func (c *concatenate) First() builder { + panic("implement me") +} + +func (c *concatenate) SetUpperLimit(count *sqlparser.SQLVal) { + // not doing anything by design +} + +func (c *concatenate) PushMisc(sel *sqlparser.Select) { + c.lhs.PushMisc(sel) + c.rhs.PushMisc(sel) +} + +func (c *concatenate) Wireup(bldr builder, jt *jointab) error { + // TODO systay should we do something different here? + err := c.lhs.Wireup(bldr, jt) + if err != nil { + return err + } + return c.rhs.Wireup(bldr, jt) +} + +func (c *concatenate) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { + panic("implement me") +} + +func (c *concatenate) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { + panic("implement me") +} + +func (c *concatenate) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + panic("implement me") +} + +func (c *concatenate) PushFilter(pb *primitiveBuilder, filter sqlparser.Expr, whereType string, origin builder) error { + return unreachable("Filter") +} + +func (c *concatenate) PushSelect(pb *primitiveBuilder, expr *sqlparser.AliasedExpr, origin builder) (rc *resultColumn, colNumber int, err error) { + return nil, 0, unreachable("Select") +} + +func (c *concatenate) MakeDistinct() error { + return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "only union-all is supported for this operator") +} + +func (c *concatenate) PushGroupBy(by sqlparser.GroupBy) error { + return unreachable("GroupBy") +} + +func (c *concatenate) PushOrderBy(by sqlparser.OrderBy) (builder, error) { + if by == nil { + return c, nil + } + return nil, unreachable("OrderBy") +} + +func (c *concatenate) Primitive() engine.Primitive { + lhs := c.lhs.Primitive() + rhs := c.rhs.Primitive() + + return &engine.Concatenate{ + Sources: []engine.Primitive{lhs, rhs}, + } +} + +// PushLock satisfies the builder interface. +func (c *concatenate) PushLock(lock string) error { + err := c.lhs.PushLock(lock) + if err != nil { + return err + } + return c.rhs.PushLock(lock) +} + +func unreachable(name string) error { + return vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "concatenate.%s: unreachable", name) +} diff --git a/go/vt/vtgate/planbuilder/join.go b/go/vt/vtgate/planbuilder/join.go index 6ec4abffce4..0ee90033f4c 100644 --- a/go/vt/vtgate/planbuilder/join.go +++ b/go/vt/vtgate/planbuilder/join.go @@ -134,6 +134,16 @@ func (jb *join) Primitive() engine.Primitive { return jb.ejoin } +// PushLock satisfies the builder interface. +func (jb *join) PushLock(lock string) error { + err := jb.Left.PushLock(lock) + if err != nil { + return err + } + + return jb.Right.PushLock(lock) +} + // First satisfies the builder interface. func (jb *join) First() builder { return jb.Left.First() diff --git a/go/vt/vtgate/planbuilder/limit.go b/go/vt/vtgate/planbuilder/limit.go index c6830400a05..210da3399c1 100644 --- a/go/vt/vtgate/planbuilder/limit.go +++ b/go/vt/vtgate/planbuilder/limit.go @@ -50,6 +50,11 @@ func (l *limit) Primitive() engine.Primitive { return l.elimit } +// PushLock satisfies the builder interface. +func (l *limit) PushLock(lock string) error { + return l.input.PushLock(lock) +} + // PushFilter satisfies the builder interface. func (l *limit) PushFilter(_ *primitiveBuilder, _ sqlparser.Expr, whereType string, _ builder) error { return errors.New("limit.PushFilter: unreachable") diff --git a/go/vt/vtgate/planbuilder/memory_sort.go b/go/vt/vtgate/planbuilder/memory_sort.go index 8fed34d0ba9..9255af36d79 100644 --- a/go/vt/vtgate/planbuilder/memory_sort.go +++ b/go/vt/vtgate/planbuilder/memory_sort.go @@ -83,6 +83,11 @@ func (ms *memorySort) Primitive() engine.Primitive { return ms.eMemorySort } +// PushLock satisfies the builder interface. +func (ms *memorySort) PushLock(lock string) error { + return ms.input.PushLock(lock) +} + // PushFilter satisfies the builder interface. func (ms *memorySort) PushFilter(_ *primitiveBuilder, _ sqlparser.Expr, whereType string, _ builder) error { return errors.New("memorySort.PushFilter: unreachable") diff --git a/go/vt/vtgate/planbuilder/merge_sort.go b/go/vt/vtgate/planbuilder/merge_sort.go index 0b423321041..249d15604e5 100644 --- a/go/vt/vtgate/planbuilder/merge_sort.go +++ b/go/vt/vtgate/planbuilder/merge_sort.go @@ -17,7 +17,8 @@ limitations under the License. package planbuilder import ( - "errors" + "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -59,6 +60,11 @@ func (ms *mergeSort) Primitive() engine.Primitive { return ms.input.Primitive() } +// PushLock satisfies the builder interface. +func (ms *mergeSort) PushLock(lock string) error { + return ms.input.PushLock(lock) +} + // PushFilter satisfies the builder interface. func (ms *mergeSort) PushFilter(pb *primitiveBuilder, expr sqlparser.Expr, whereType string, origin builder) error { return ms.input.PushFilter(pb, expr, whereType, origin) @@ -83,7 +89,7 @@ func (ms *mergeSort) PushGroupBy(groupBy sqlparser.GroupBy) error { // A merge sort is created due to the push of an ORDER BY clause. // So, this function should never get called. func (ms *mergeSort) PushOrderBy(orderBy sqlparser.OrderBy) (builder, error) { - return nil, errors.New("mergeSort.PushOrderBy: unreachable") + return nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "can't do ORDER BY on top of ORDER BY") } // Wireup satisfies the builder interface. diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index bb9853c4eaf..6a7324e6f5e 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -243,6 +243,11 @@ func (oa *orderedAggregate) Primitive() engine.Primitive { return oa.eaggr } +// PushLock satisfies the builder interface. +func (oa *orderedAggregate) PushLock(lock string) error { + return oa.input.PushLock(lock) +} + // PushFilter satisfies the builder interface. func (oa *orderedAggregate) PushFilter(_ *primitiveBuilder, _ sqlparser.Expr, whereType string, _ builder) error { return errors.New("unsupported: filtering on results of aggregates") diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 38582e14cce..f91aa539210 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -26,10 +26,11 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -171,6 +172,7 @@ func TestPlan(t *testing.T) { testFile(t, "use_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "set_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "set_sysvar_cases.txt", testOutputTempDir, vschemaWrapper) + testFile(t, "union_cases.txt", testOutputTempDir, vschemaWrapper) } func TestOne(t *testing.T) { @@ -331,7 +333,7 @@ func testFile(t *testing.T, filename, tempDir string, vschema *vschemaWrapper) { if out != tcase.output { fail = true - t.Errorf("File: %s, Line: %v\n %s \n%s", filename, tcase.lineno, cmp.Diff(tcase.output, out), out) + t.Errorf("File: %s, Line: %d\nDiff:\n%s\n[%s] \n[%s]", filename, tcase.lineno, cmp.Diff(tcase.output, out), tcase.output, out) } if err != nil { @@ -339,7 +341,6 @@ func testFile(t *testing.T, filename, tempDir string, vschema *vschemaWrapper) { } expected.WriteString(fmt.Sprintf("%s\"%s\"\n%s\n\n", tcase.comments, tcase.input, out)) - }) } if fail && tempDir != "" { diff --git a/go/vt/vtgate/planbuilder/pullout_subquery.go b/go/vt/vtgate/planbuilder/pullout_subquery.go index 35d189df8b1..9977b0a6359 100644 --- a/go/vt/vtgate/planbuilder/pullout_subquery.go +++ b/go/vt/vtgate/planbuilder/pullout_subquery.go @@ -71,6 +71,16 @@ func (ps *pulloutSubquery) Primitive() engine.Primitive { return ps.eSubquery } +// PushLock satisfies the builder interface. +func (ps *pulloutSubquery) PushLock(lock string) error { + err := ps.subquery.PushLock(lock) + if err != nil { + return err + } + + return ps.underlying.PushLock(lock) +} + // First satisfies the builder interface. func (ps *pulloutSubquery) First() builder { return ps.underlying.First() diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index 6a7488e8441..4f810805cdf 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -92,6 +92,12 @@ func (rb *route) Primitive() engine.Primitive { return rb.routeOptions[0].eroute } +// PushLock satisfies the builder interface. +func (rb *route) PushLock(lock string) error { + rb.Select.SetLock(lock) + return nil +} + // First satisfies the builder interface. func (rb *route) First() builder { return rb @@ -390,7 +396,10 @@ func (rb *route) generateFieldQuery(sel sqlparser.SelectStatement, jt *jointab) sqlparser.FormatImpossibleQuery(buf, node) } - return sqlparser.NewTrackedBuffer(formatter).WriteNode(sel).ParsedQuery().Query + buffer := sqlparser.NewTrackedBuffer(formatter) + node := buffer.WriteNode(sel) + query := node.ParsedQuery() + return query.Query } // SupplyVar satisfies the builder interface. diff --git a/go/vt/vtgate/planbuilder/subquery.go b/go/vt/vtgate/planbuilder/subquery.go index 98c5772d73e..fcfeac63419 100644 --- a/go/vt/vtgate/planbuilder/subquery.go +++ b/go/vt/vtgate/planbuilder/subquery.go @@ -74,6 +74,11 @@ func (sq *subquery) Primitive() engine.Primitive { return sq.esubquery } +// PushLock satisfies the builder interface. +func (sq *subquery) PushLock(lock string) error { + return sq.input.PushLock(lock) +} + // First satisfies the builder interface. func (sq *subquery) First() builder { return sq diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.txt b/go/vt/vtgate/planbuilder/testdata/onecase.txt index f821c90b4e5..e819513f354 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.txt +++ b/go/vt/vtgate/planbuilder/testdata/onecase.txt @@ -1,4 +1 @@ -# union all using sharded -"select id from user union all select id from user" -{ -} +# Add your test case here for debugging and run go test -run=One. diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.txt b/go/vt/vtgate/planbuilder/testdata/union_cases.txt new file mode 100644 index 00000000000..f975b7e4e06 --- /dev/null +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.txt @@ -0,0 +1,243 @@ +# union all between two scatter selects +"select id from user union all select id from music" +{ + "QueryType": "SELECT", + "Original": "select id from user union all select id from music", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user", + "Table": "user" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from music where 1 != 1", + "Query": "select id from music", + "Table": "music" + } + ] + } +} + +# union all between two SelectEqualUnique +"select id from user where id = 1 union all select id from user where id = 5" +{ + "QueryType": "SELECT", + "Original": "select id from user where id = 1 union all select id from user where id = 5", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user where id = 1", + "Table": "user", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user where id = 5", + "Table": "user", + "Values": [ + 5 + ], + "Vindex": "user_index" + } + ] + } +} + +#almost dereks query - two queries with order by and limit being scattered to two different sets of tablets +"(SELECT id FROM user ORDER BY id DESC LIMIT 1) UNION ALL (SELECT id FROM music ORDER BY id DESC LIMIT 1)" +{ + "QueryType": "SELECT", + "Original": "(SELECT id FROM user ORDER BY id DESC LIMIT 1) UNION ALL (SELECT id FROM music ORDER BY id DESC LIMIT 1)", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user order by id desc limit :__upper_limit", + "Table": "user" + } + ] + }, + { + "OperatorType": "Limit", + "Count": 1, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from music where 1 != 1", + "Query": "select id from music order by id desc limit :__upper_limit", + "Table": "music" + } + ] + } + ] + } +} + +# Union all +"select col1, col2 from user union all select col1, col2 from user_extra" +{ + "QueryType": "SELECT", + "Original": "select col1, col2 from user union all select col1, col2 from user_extra", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2 from user where 1 != 1", + "Query": "select col1, col2 from user", + "Table": "user" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col1, col2 from user_extra where 1 != 1", + "Query": "select col1, col2 from user_extra", + "Table": "user_extra" + } + ] + } +} + +# union operations in subqueries (FROM) +"select * from (select * from user union all select * from user_extra) as t" +{ + "QueryType": "SELECT", + "Original": "select * from (select * from user union all select * from user_extra) as t", + "Instructions": { + "OperatorType": "Subquery", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from user where 1 != 1", + "Query": "select * from user", + "Table": "user" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from user_extra where 1 != 1", + "Query": "select * from user_extra", + "Table": "user_extra" + } + ] + } + ] + } +} + +# union all between two scatter selects, with order by +"(select id from user order by id limit 5) union all (select id from music order by id desc limit 5)" +{ + "QueryType": "SELECT", + "Original": "(select id from user order by id limit 5) union all (select id from music order by id desc limit 5)", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": 5, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user order by id asc limit :__upper_limit", + "Table": "user" + } + ] + }, + { + "OperatorType": "Limit", + "Count": 5, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from music where 1 != 1", + "Query": "select id from music order by id desc limit :__upper_limit", + "Table": "music" + } + ] + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 8943b38b889..ca4723ae58b 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -14,10 +14,6 @@ "show create database" "plan building not supported" -# union operations in subqueries (FROM) -"select * from (select * from user union all select * from user_extra) as t" -"unsupported: UNION cannot be executed as a single route" - # union operations in subqueries (expressions) "select * from user where id in (select * from user union select * from user_extra)" "unsupported: UNION cannot be executed as a single route" @@ -348,16 +344,12 @@ # multi-shard union "select 1 from music union (select id from user union all select name from unsharded)" -"unsupported: UNION cannot be executed as a single route" +"unsupported: SELECT of UNION is non-trivial" # multi-shard union "select 1 from music union (select id from user union select name from unsharded)" "unsupported: UNION cannot be executed as a single route" -# multi-shard union -"select id from user union all select id from music" -"unsupported: UNION cannot be executed as a single route" - # union with the same target shard because of vindex "select * from music where id = 1 union select * from user where id = 1" "unsupported: UNION cannot be executed as a single route" @@ -366,10 +358,6 @@ "select 1 from music where id = 1 union select 1 from music where id = 2" "unsupported: UNION cannot be executed as a single route" -# Union all -"select col1, col2 from user union all select col1, col2 from user_extra" -"unsupported: UNION cannot be executed as a single route" - "(select user.id, user.name from user join user_extra where user_extra.extra = 'asdf') union select 'b','c' from user" "unsupported: SELECT of UNION is non-trivial" @@ -414,12 +402,24 @@ # order by inside and outside parenthesis select "(select 1 from user order by 1 desc) order by 1 asc limit 2" -"expected union to produce a route" +"can't do ORDER BY on top of ORDER BY" # multiple select statement have inner order by with union "(select 1 from user order by 1 desc) union (select 1 from user order by 1 asc)" "unsupported: SELECT of UNION is non-trivial" +# different number of columns +"select id, 42 from user where id = 1 union all select id from user where id = 5" +"The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000) during query: select id, 42 from user where id = 1 union all select id from user where id = 5" + +# ambiguous ORDER BY +"select id from user order by id union all select id from music order by id desc" +"Incorrect usage of UNION and ORDER BY - add parens to disambiguate your query (errno 1221) (sqlstate 21000)" + +# ambiguous LIMIT +"select id from user limit 1 union all select id from music limit 1" +"Incorrect usage of UNION and LIMIT - add parens to disambiguate your query (errno 1221) (sqlstate 21000)" + # Savepoint "savepoint a" "unsupported: Savepoint construct savepoint a" diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index fb3845bc198..42f692cedb6 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -20,8 +20,7 @@ import ( "errors" "fmt" - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" @@ -41,25 +40,40 @@ func buildUnionPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Pr } func (pb *primitiveBuilder) processUnion(union *sqlparser.Union, outer *symtab) error { - if err := pb.processPart(union.FirstStatement, outer); err != nil { + if err := pb.processPart(union.FirstStatement, outer, false); err != nil { return err } for _, us := range union.UnionSelects { rpb := newPrimitiveBuilder(pb.vschema, pb.jt) - if err := rpb.processPart(us.Statement, outer); err != nil { + if err := rpb.processPart(us.Statement, outer, false); err != nil { return err } + err := unionRouteMerge(pb.bldr, rpb.bldr, us) + if err != nil { + if us.Type != sqlparser.UnionAllStr { + return err + } - if err := unionRouteMerge(pb.bldr, rpb.bldr); err != nil { - return err + // we are merging between two routes - let's check if we can see so that we have the same amount of columns on both sides of the union + lhsCols := len(pb.bldr.ResultColumns()) + rhsCols := len(rpb.bldr.ResultColumns()) + if lhsCols != rhsCols { + return &mysql.SQLError{ + Num: mysql.ERWrongNumberOfColumnsInSelect, + State: "21000", + Message: "The used SELECT statements have a different number of columns", + Query: sqlparser.String(union), + } + } + + pb.bldr = &concatenate{ + lhs: pb.bldr, + rhs: rpb.bldr, + } } pb.st.Outer = outer } - unionRoute, ok := pb.bldr.(*route) - if !ok { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "expected union to produce a route") - } - unionRoute.Select = &sqlparser.Union{FirstStatement: union.FirstStatement, UnionSelects: union.UnionSelects, Lock: union.Lock} + pb.bldr.PushLock(union.Lock) if err := pb.pushOrderBy(union.OrderBy); err != nil { return err @@ -67,19 +81,52 @@ func (pb *primitiveBuilder) processUnion(union *sqlparser.Union, outer *symtab) return pb.pushLimit(union.Limit) } -func (pb *primitiveBuilder) processPart(part sqlparser.SelectStatement, outer *symtab) error { +func (pb *primitiveBuilder) processPart(part sqlparser.SelectStatement, outer *symtab, hasParens bool) error { switch part := part.(type) { case *sqlparser.Union: return pb.processUnion(part, outer) case *sqlparser.Select: + if !hasParens { + err := checkOrderByAndLimit(part) + if err != nil { + return err + } + } return pb.processSelect(part, outer) case *sqlparser.ParenSelect: - return pb.processPart(part.Select, outer) + err := pb.processPart(part.Select, outer, true) + if err != nil { + return err + } + // TODO: This is probably not a great idea. If we ended up with something other than a route, we'll lose the parens + routeOp, ok := pb.bldr.(*route) + if ok { + routeOp.Select = &sqlparser.ParenSelect{Select: routeOp.Select} + } + return nil } return fmt.Errorf("BUG: unexpected SELECT type: %T", part) } -func unionRouteMerge(left, right builder) error { +func checkOrderByAndLimit(part *sqlparser.Select) error { + if part.OrderBy != nil { + return &mysql.SQLError{ + Num: mysql.ERWrongUsage, + State: "21000", + Message: "Incorrect usage of UNION and ORDER BY - add parens to disambiguate your query", + } + } + if part.Limit != nil { + return &mysql.SQLError{ + Num: mysql.ERWrongUsage, + State: "21000", + Message: "Incorrect usage of UNION and LIMIT - add parens to disambiguate your query", + } + } + return nil +} + +func unionRouteMerge(left, right builder, us *sqlparser.UnionSelect) error { lroute, ok := left.(*route) if !ok { return errors.New("unsupported: SELECT of UNION is non-trivial") @@ -92,5 +139,12 @@ func unionRouteMerge(left, right builder) error { return errors.New("unsupported: UNION cannot be executed as a single route") } + switch n := lroute.Select.(type) { + case *sqlparser.Union: + n.UnionSelects = append(n.UnionSelects, us) + default: + lroute.Select = &sqlparser.Union{FirstStatement: lroute.Select, UnionSelects: []*sqlparser.UnionSelect{us}} + } + return nil } diff --git a/go/vt/vtgate/planbuilder/vindex_func.go b/go/vt/vtgate/planbuilder/vindex_func.go index 3902f4a1030..f158acbee24 100644 --- a/go/vt/vtgate/planbuilder/vindex_func.go +++ b/go/vt/vtgate/planbuilder/vindex_func.go @@ -82,6 +82,11 @@ func (vf *vindexFunc) Primitive() engine.Primitive { return vf.eVindexFunc } +// PushLock satisfies the builder interface. +func (vf *vindexFunc) PushLock(lock string) error { + return nil +} + // First satisfies the builder interface. func (vf *vindexFunc) First() builder { return vf From 7add5328ada2a5bdf2509b02ccf3261f589b1038 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 26 Jun 2020 19:05:32 +0530 Subject: [PATCH 06/13] union-all: addressed review comments Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 97 +++++++++++++++++++------ go/vt/vtgate/engine/concatenate_test.go | 8 +- 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 45331c3cbe6..d11d7ded29a 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -19,6 +19,7 @@ package engine import ( "sort" "strings" + "sync" "vitess.io/vitess/go/mysql" @@ -67,15 +68,26 @@ func (c *Concatenate) GetTableName() string { // Execute performs a non-streaming exec. func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { result := &sqltypes.Result{} - for _, source := range c.Sources { - qr, err := source.Execute(vcursor, bindVars, wantfields) - if err != nil { - return nil, vterrors.Wrap(err, "Concatenate.Execute") + var wg sync.WaitGroup + qrs := make([]*sqltypes.Result, len(c.Sources)) + errs := make([]error, len(c.Sources)) + for i, source := range c.Sources { + wg.Add(1) + go func(i int, source Primitive) { + defer wg.Done() + qrs[i], errs[i] = source.Execute(vcursor, bindVars, wantfields) + }(i, source) + } + wg.Wait() + for i := 0; i < len(c.Sources); i++ { + if errs[i] != nil { + return nil, vterrors.Wrap(errs[i], "Concatenate.Execute") } + qr := qrs[i] if result.Fields == nil { result.Fields = qr.Fields } - err = compareFields(result.Fields, qr.Fields) + err := compareFields(result.Fields, qr.Fields) if err != nil { return nil, err } @@ -94,29 +106,68 @@ func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { var seenFields []*querypb.Field columnCount := 0 - for _, source := range c.Sources { - err := source.StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { - // if we have fields to compare, make sure all the fields are all the same - if seenFields == nil { - seenFields = resultChunk.Fields - } else if resultChunk.Fields != nil { - err := compareFields(seenFields, resultChunk.Fields) - if err != nil { - return err + + // To return deterministic field names. + err := c.Sources[0].StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + // if we have fields to compare, make sure all the fields are all the same + if seenFields == nil { + seenFields = resultChunk.Fields + } else if resultChunk.Fields != nil { + err := compareFields(seenFields, resultChunk.Fields) + if err != nil { + return err + } + } + if len(resultChunk.Rows) > 0 { + if columnCount == 0 { + columnCount = len(resultChunk.Rows[0]) + } else { + if len(resultChunk.Rows[0]) != columnCount { + return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") } } - if len(resultChunk.Rows) > 0 { - if columnCount == 0 { - columnCount = len(resultChunk.Rows[0]) - } else { - if len(resultChunk.Rows[0]) != columnCount { - return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") + } + + return callback(resultChunk) + }) + if err != nil { + return err + } + + var errs []error + var wg sync.WaitGroup + for i := 1; i < len(c.Sources); i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + err := c.Sources[i].StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + // if we have fields to compare, make sure all the fields are all the same + if seenFields == nil { + seenFields = resultChunk.Fields + } else if resultChunk.Fields != nil { + err := compareFields(seenFields, resultChunk.Fields) + if err != nil { + return err } } - } + if len(resultChunk.Rows) > 0 { + if columnCount == 0 { + columnCount = len(resultChunk.Rows[0]) + } else { + if len(resultChunk.Rows[0]) != columnCount { + return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") + } + } + } + + return callback(resultChunk) + }) + errs = append(errs, err) + }(i) + } + wg.Wait() - return callback(resultChunk) - }) + for _, err := range errs { if err != nil { return err } diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index 796eb3da73f..cfde9ff577b 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -34,10 +34,10 @@ func TestConcatenate_Execute(t *testing.T) { testCases := []*testCase{{ testName: "empty results", inputs: []*sqltypes.Result{ - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), - sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id2|col21|col22", "int64|varbinary|varbinary")), }, - expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary")), + expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), }, { testName: "2 non empty result", inputs: []*sqltypes.Result{ @@ -72,7 +72,7 @@ func TestConcatenate_Execute(t *testing.T) { t.Run(tc.testName, func(t *testing.T) { var fps []Primitive for _, input := range tc.inputs { - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input, input, input}}) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input, input, input, input, input}}) } concatenate := &Concatenate{Sources: fps} From 6206c35ea8c6928fb1d60610d9fae1bc8d37f04f Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Fri, 26 Jun 2020 19:18:58 +0530 Subject: [PATCH 07/13] union-all: concatenate stream-execute parallelize Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/misc_test.go | 4 ++ go/vt/vtgate/engine/concatenate.go | 59 ++++++------------------- go/vt/vtgate/engine/concatenate_test.go | 1 + go/vt/vtgate/executor_select_test.go | 38 ++++++++++++++++ 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index d3f379f0d78..da66cd5d8c2 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -575,6 +575,10 @@ func TestUnionAll(t *testing.T) { // union all between two different tables assertMatches(t, conn, "(select id1,id2 from t1 order by id1) union all (select id3,id4 from t2 order by id3)", "[[INT64(1) INT64(1)] [INT64(2) INT64(2)] [INT64(3) INT64(3)] [INT64(4) INT64(4)]]") + + // union all between two different tables + assertMatches(t, conn, "select tbl2.id1 FROM ((select id1 from t1 order by id1 limit 5) union all (select id1 from t1 order by id1 desc limit 5)) as tbl1 INNER JOIN t1 as tbl2 ON tbl1.id1 = tbl2.id1", + "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]") } func TestUnion(t *testing.T) { diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index d11d7ded29a..c7d78d1cffe 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -105,65 +105,34 @@ func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind // StreamExecute performs a streaming exec. func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { var seenFields []*querypb.Field - columnCount := 0 - - // To return deterministic field names. - err := c.Sources[0].StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { - // if we have fields to compare, make sure all the fields are all the same - if seenFields == nil { - seenFields = resultChunk.Fields - } else if resultChunk.Fields != nil { - err := compareFields(seenFields, resultChunk.Fields) - if err != nil { - return err - } - } - if len(resultChunk.Rows) > 0 { - if columnCount == 0 { - columnCount = len(resultChunk.Rows[0]) - } else { - if len(resultChunk.Rows[0]) != columnCount { - return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") - } - } - } - - return callback(resultChunk) - }) - if err != nil { - return err - } + fieldsSent := false var errs []error - var wg sync.WaitGroup - for i := 1; i < len(c.Sources); i++ { + var wg, fieldset sync.WaitGroup + fieldset.Add(1) + for i, source := range c.Sources { wg.Add(1) - go func(i int) { + go func(i int, source Primitive) { defer wg.Done() - err := c.Sources[i].StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { + err := source.StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { // if we have fields to compare, make sure all the fields are all the same - if seenFields == nil { + if i == 0 && !fieldsSent { + defer fieldset.Done() seenFields = resultChunk.Fields - } else if resultChunk.Fields != nil { + fieldsSent = true + return callback(resultChunk) + } + fieldset.Wait() + if resultChunk.Fields != nil { err := compareFields(seenFields, resultChunk.Fields) if err != nil { return err } } - if len(resultChunk.Rows) > 0 { - if columnCount == 0 { - columnCount = len(resultChunk.Rows[0]) - } else { - if len(resultChunk.Rows[0]) != columnCount { - return mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The usasdfasded SELECT statements have a different number of columns") - } - } - } - return callback(resultChunk) }) errs = append(errs, err) - }(i) + }(i, source) } wg.Wait() diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index cfde9ff577b..c09be81b123 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -36,6 +36,7 @@ func TestConcatenate_Execute(t *testing.T) { inputs: []*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), sqltypes.MakeTestResult(sqltypes.MakeTestFields("id2|col21|col22", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id3|col31|col32", "int64|varbinary|varbinary")), }, expectedResult: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), }, { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 53d82ca3ed0..9c27df58629 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2157,3 +2157,41 @@ func TestSelectBindvarswithPrepare(t *testing.T) { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } } + +func TestSelectWithUnionAll(t *testing.T) { + executor, sbc1, sbc2, _ := createExecutorEnv() + executor.normalize = true + sql := "select id from user where id = 1 union select id from user where id = 1 union all select id from user where id in (1, 2, 3)" + _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + bv, err := sqltypes.BuildBindVariable([]int64{1, 2, 3}) + require.NoError(t, err) + bv1, err := sqltypes.BuildBindVariable([]int64{1, 2}) + require.NoError(t, err) + bv2, err := sqltypes.BuildBindVariable([]int64{3}) + require.NoError(t, err) + sbc1WantQueries := []*querypb.BoundQuery{{ + Sql: "select id from user where id = :vtg1 union select id from user where id = :vtg1", + BindVariables: map[string]*querypb.BindVariable{ + "vtg1": sqltypes.Int64BindVariable(1), + "vtg2": bv, + }, + }, { + Sql: "select id from user where id in ::__vals", + BindVariables: map[string]*querypb.BindVariable{ + "__vals": bv1, + "vtg1": sqltypes.Int64BindVariable(1), + "vtg2": bv, + }, + }} + sbc2WantQueries := []*querypb.BoundQuery{{ + Sql: "select id from user where id in ::__vals", + BindVariables: map[string]*querypb.BindVariable{ + "__vals": bv2, + "vtg1": sqltypes.Int64BindVariable(1), + "vtg2": bv, + }, + }} + utils.MustMatch(t, sbc2WantQueries, sbc2.Queries, "sbc2") + utils.MustMatch(t, sbc1WantQueries, sbc1.Queries, "sbc1") +} From c78910dc8a2e7afd958d1d63cc7decad1302ab5f Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Sat, 27 Jun 2020 00:58:34 +0530 Subject: [PATCH 08/13] union-all: fix data race Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 8 +++- go/vt/vtgate/executor_select_test.go | 42 ++++++++++++------- .../planbuilder/testdata/union_cases.txt | 39 +++++++++++++++++ go/vt/vttablet/sandboxconn/sandboxconn.go | 8 ++++ 4 files changed, 81 insertions(+), 16 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index c7d78d1cffe..b6d8639b582 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -107,9 +107,10 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp var seenFields []*querypb.Field fieldsSent := false - var errs []error + errs := make([]error, len(c.Sources)) var wg, fieldset sync.WaitGroup fieldset.Add(1) + var mu sync.Mutex for i, source := range c.Sources { wg.Add(1) go func(i int, source Primitive) { @@ -120,6 +121,7 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp defer fieldset.Done() seenFields = resultChunk.Fields fieldsSent = true + // No other call can happen before this call. return callback(resultChunk) } fieldset.Wait() @@ -129,9 +131,11 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp return err } } + mu.Lock() + defer mu.Unlock() return callback(resultChunk) }) - errs = append(errs, err) + errs[i] = err }(i, source) } wg.Wait() diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 9c27df58629..628b4796bff 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2161,26 +2161,22 @@ func TestSelectBindvarswithPrepare(t *testing.T) { func TestSelectWithUnionAll(t *testing.T) { executor, sbc1, sbc2, _ := createExecutorEnv() executor.normalize = true - sql := "select id from user where id = 1 union select id from user where id = 1 union all select id from user where id in (1, 2, 3)" - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - bv, err := sqltypes.BuildBindVariable([]int64{1, 2, 3}) - require.NoError(t, err) - bv1, err := sqltypes.BuildBindVariable([]int64{1, 2}) - require.NoError(t, err) - bv2, err := sqltypes.BuildBindVariable([]int64{3}) - require.NoError(t, err) + sql := "select id from user where id in (1, 2, 3) union all select id from user where id in (1, 2, 3)" + bv, _ := sqltypes.BuildBindVariable([]int64{1, 2, 3}) + bv1, _ := sqltypes.BuildBindVariable([]int64{1, 2}) + bv2, _ := sqltypes.BuildBindVariable([]int64{3}) sbc1WantQueries := []*querypb.BoundQuery{{ - Sql: "select id from user where id = :vtg1 union select id from user where id = :vtg1", + Sql: "select id from user where id in ::__vals", BindVariables: map[string]*querypb.BindVariable{ - "vtg1": sqltypes.Int64BindVariable(1), - "vtg2": bv, + "__vals": bv1, + "vtg1": bv, + "vtg2": bv, }, }, { Sql: "select id from user where id in ::__vals", BindVariables: map[string]*querypb.BindVariable{ "__vals": bv1, - "vtg1": sqltypes.Int64BindVariable(1), + "vtg1": bv, "vtg2": bv, }, }} @@ -2188,10 +2184,28 @@ func TestSelectWithUnionAll(t *testing.T) { Sql: "select id from user where id in ::__vals", BindVariables: map[string]*querypb.BindVariable{ "__vals": bv2, - "vtg1": sqltypes.Int64BindVariable(1), + "vtg1": bv, + "vtg2": bv, + }, + }, { + Sql: "select id from user where id in ::__vals", + BindVariables: map[string]*querypb.BindVariable{ + "__vals": bv2, + "vtg1": bv, "vtg2": bv, }, }} + _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + require.NoError(t, err) + utils.MustMatch(t, sbc1WantQueries, sbc1.Queries, "sbc1") utils.MustMatch(t, sbc2WantQueries, sbc2.Queries, "sbc2") + + // Reset + sbc1.Queries = nil + sbc2.Queries = nil + + _, err = executorStream(executor, sql) + require.NoError(t, err) utils.MustMatch(t, sbc1WantQueries, sbc1.Queries, "sbc1") + utils.MustMatch(t, sbc2WantQueries, sbc2.Queries, "sbc2") } diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.txt b/go/vt/vtgate/planbuilder/testdata/union_cases.txt index f975b7e4e06..8a88dc48f0c 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.txt @@ -241,3 +241,42 @@ ] } } + +# union all on scatter and single route +"select id from user where id = 1 union select id from user where id = 1 union all select id from user" +{ + "QueryType": "SELECT", + "Original": "select id from user where id = 1 union select id from user where id = 1 union all select id from user", + "Instructions": { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1 union select id from user where 1 != 1", + "Query": "select id from user where id = 1 union select id from user where id = 1", + "Table": "user", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from user where 1 != 1", + "Query": "select id from user", + "Table": "user" + } + ] + } +} + diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index 4a75deb5cc7..e9e16557870 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -20,6 +20,7 @@ package sandboxconn import ( "fmt" + "sync" "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -93,6 +94,9 @@ type SandboxConn struct { // transaction id generator TransactionID sync2.AtomicInt64 + + sExecMu sync.Mutex + execMu sync.Mutex } var _ queryservice.QueryService = (*SandboxConn)(nil) // compile-time interface check @@ -128,6 +132,8 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que for k, v := range bindVars { bv[k] = v } + sbc.execMu.Lock() + defer sbc.execMu.Unlock() sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ Sql: query, BindVariables: bv, @@ -164,6 +170,8 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe for k, v := range bindVars { bv[k] = v } + sbc.sExecMu.Lock() + defer sbc.sExecMu.Unlock() sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ Sql: query, BindVariables: bv, From 8640d59f145a251d4f8ac1902ddb8e54fb63bd05 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 29 Jun 2020 08:55:55 +0530 Subject: [PATCH 09/13] union-all: data race fix on sandbox client and added additional test Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/misc_test.go | 19 +++ go/vt/vtgate/engine/concatenate.go | 5 + go/vt/vtgate/engine/concatenate_test.go | 145 +++++++++++++++++++++- go/vt/vttablet/sandboxconn/sandboxconn.go | 13 +- 4 files changed, 175 insertions(+), 7 deletions(-) diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index da66cd5d8c2..dba04c1cb01 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -19,6 +19,8 @@ package vtgate import ( "context" "fmt" + "sort" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -579,6 +581,23 @@ func TestUnionAll(t *testing.T) { // union all between two different tables assertMatches(t, conn, "select tbl2.id1 FROM ((select id1 from t1 order by id1 limit 5) union all (select id1 from t1 order by id1 desc limit 5)) as tbl1 INNER JOIN t1 as tbl2 ON tbl1.id1 = tbl2.id1", "[[INT64(1)] [INT64(2)] [INT64(2)] [INT64(1)]]") + + exec(t, conn, "insert into t1(id1, id2) values(3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8)") + + // union all between two selectuniquein tables + qr := exec(t, conn, "select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8) union all select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8)") + expected := sortString("[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)] [INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)]]") + assert.Equal(t, expected, sortString(fmt.Sprintf("%v", qr.Rows))) + + // clean up + exec(t, conn, "delete from t1") + exec(t, conn, "delete from t2") +} + +func sortString(w string) string { + s := strings.Split(w, "") + sort.Strings(s) + return strings.Join(s, "") } func TestUnion(t *testing.T) { diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index b6d8639b582..1f6e32b788b 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -131,10 +131,15 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp return err } } + // This to ensure only one send happens back to the client. mu.Lock() defer mu.Unlock() return callback(resultChunk) }) + // This is to ensure other streams complete if the first stream failed to unlock the wait. + if i == 0 && !fieldsSent { + fieldset.Done() + } errs[i] = err }(i, source) } diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index c09be81b123..facd781743a 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -17,13 +17,14 @@ limitations under the License. package engine import ( + "errors" "testing" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" ) -func TestConcatenate_Execute(t *testing.T) { +func TestConcatenate_NoSourcesErr(t *testing.T) { type testCase struct { testName string inputs []*sqltypes.Result @@ -73,7 +74,7 @@ func TestConcatenate_Execute(t *testing.T) { t.Run(tc.testName, func(t *testing.T) { var fps []Primitive for _, input := range tc.inputs { - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input, input, input, input, input}}) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}, sendErr: errors.New("abc")}) } concatenate := &Concatenate{Sources: fps} @@ -99,3 +100,143 @@ func TestConcatenate_Execute(t *testing.T) { }) } } + +func TestConcatenate_WithSourcesErrFirst(t *testing.T) { + type testCase struct { + testName string + inputs []*sqltypes.Result + } + + executeErr := "Concatenate.Execute: failed" + streamExecuteErr := "failed" + + testCases := []*testCase{{ + testName: "empty results", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id2|col21|col22", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id3|col31|col32", "int64|varbinary|varbinary")), + }, + }, { + testName: "2 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + }, { + testName: "mismatch field type", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + }, { + testName: "input source has different column count", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), + }, + }, { + testName: "1 empty result and 1 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + }} + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + var fps []Primitive + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New("failed")}) + for _, input := range tc.inputs { + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) + } + concatenate := &Concatenate{Sources: fps} + + t.Run("Execute wantfields true", func(t *testing.T) { + _, err := concatenate.Execute(&noopVCursor{}, nil, true) + require.EqualError(t, err, executeErr) + + }) + + t.Run("StreamExecute wantfields true", func(t *testing.T) { + _, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) + require.EqualError(t, err, streamExecuteErr) + }) + }) + } +} + +func TestConcatenate_WithSourcesErrLast(t *testing.T) { + type testCase struct { + testName string + inputs []*sqltypes.Result + execErr, streamExecErr string + } + + executeErr := "Concatenate.Execute: failed" + streamExecuteErr := "failed" + + testCases := []*testCase{{ + testName: "empty results", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id1|col11|col12", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id2|col21|col22", "int64|varbinary|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id3|col31|col32", "int64|varbinary|varbinary")), + }, + execErr: executeErr, + streamExecErr: streamExecuteErr, + }, { + testName: "2 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary"), "11|m1|n1", "22|m2|n2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + execErr: executeErr, + streamExecErr: streamExecuteErr, + }, { + testName: "mismatch field type", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + execErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + streamExecErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + }, { + testName: "input source has different column count", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), + }, + execErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + streamExecErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + }, { + testName: "1 empty result and 1 non empty result", + inputs: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("myid|mycol1|mycol2", "int64|varchar|varbinary")), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), + }, + execErr: executeErr, + streamExecErr: streamExecuteErr, + }} + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + var fps []Primitive + for _, input := range tc.inputs { + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) + } + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New("failed")}) + concatenate := &Concatenate{Sources: fps} + + t.Run("Execute wantfields true", func(t *testing.T) { + _, err := concatenate.Execute(&noopVCursor{}, nil, true) + require.EqualError(t, err, tc.execErr) + }) + + t.Run("StreamExecute wantfields true", func(t *testing.T) { + _, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) + require.EqualError(t, err, tc.streamExecErr) + }) + }) + } +} diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index e9e16557870..229b2ddd8f0 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -127,13 +127,13 @@ func (sbc *SandboxConn) SetResults(r []*sqltypes.Result) { // Execute is part of the QueryService interface. func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { + sbc.execMu.Lock() + defer sbc.execMu.Unlock() sbc.ExecCount.Add(1) bv := make(map[string]*querypb.BindVariable) for k, v := range bindVars { bv[k] = v } - sbc.execMu.Lock() - defer sbc.execMu.Unlock() sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ Sql: query, BindVariables: bv, @@ -165,13 +165,12 @@ func (sbc *SandboxConn) ExecuteBatch(ctx context.Context, target *querypb.Target // StreamExecute is part of the QueryService interface. func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { + sbc.sExecMu.Lock() sbc.ExecCount.Add(1) bv := make(map[string]*querypb.BindVariable) for k, v := range bindVars { bv[k] = v } - sbc.sExecMu.Lock() - defer sbc.sExecMu.Unlock() sbc.Queries = append(sbc.Queries, &querypb.BoundQuery{ Sql: query, BindVariables: bv, @@ -179,9 +178,13 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe sbc.Options = append(sbc.Options, options) err := sbc.getError() if err != nil { + sbc.sExecMu.Unlock() return err } - return callback(sbc.getNextResult()) + nextRs := sbc.getNextResult() + sbc.sExecMu.Unlock() + + return callback(nextRs) } // Begin is part of the QueryService interface. From 0b585ba91ee2a17d7fc4daf183944d6caf5b95de Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 29 Jun 2020 13:51:10 +0530 Subject: [PATCH 10/13] union-all: use errgroup context to cancel parallel streams Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/concatenate.go | 29 +++++++------- go/vt/vtgate/engine/concatenate_test.go | 48 +++++++++++++----------- go/vt/vtgate/engine/fake_vcursor_test.go | 22 +++++++++-- go/vt/vtgate/engine/primitive.go | 5 +++ go/vt/vtgate/vcursor_impl.go | 9 +++++ 5 files changed, 74 insertions(+), 39 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 1f6e32b788b..3e1539dcb7c 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -105,16 +105,15 @@ func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind // StreamExecute performs a streaming exec. func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { var seenFields []*querypb.Field + var fieldset sync.WaitGroup fieldsSent := false - errs := make([]error, len(c.Sources)) - var wg, fieldset sync.WaitGroup + g := vcursor.ErrorGroupCancellableContext() fieldset.Add(1) var mu sync.Mutex for i, source := range c.Sources { - wg.Add(1) - go func(i int, source Primitive) { - defer wg.Done() + i, source := i, source + g.Go(func() error { err := source.StreamExecute(vcursor, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { // if we have fields to compare, make sure all the fields are all the same if i == 0 && !fieldsSent { @@ -134,21 +133,23 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp // This to ensure only one send happens back to the client. mu.Lock() defer mu.Unlock() - return callback(resultChunk) + select { + case <-vcursor.Context().Done(): + return nil + default: + return callback(resultChunk) + } }) // This is to ensure other streams complete if the first stream failed to unlock the wait. if i == 0 && !fieldsSent { fieldset.Done() } - errs[i] = err - }(i, source) - } - wg.Wait() - - for _, err := range errs { - if err != nil { return err - } + }) + + } + if err := g.Wait(); err != nil { + return err } return nil } diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index facd781743a..7ad3207d4eb 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -17,6 +17,7 @@ limitations under the License. package engine import ( + "context" "errors" "testing" @@ -79,7 +80,7 @@ func TestConcatenate_NoSourcesErr(t *testing.T) { concatenate := &Concatenate{Sources: fps} t.Run("Execute wantfields true", func(t *testing.T) { - qr, err := concatenate.Execute(&noopVCursor{}, nil, true) + qr, err := concatenate.Execute(&noopVCursor{ctx: context.Background()}, nil, true) if tc.expectedError == "" { require.NoError(t, err) require.Equal(t, tc.expectedResult, qr) @@ -89,7 +90,7 @@ func TestConcatenate_NoSourcesErr(t *testing.T) { }) t.Run("StreamExecute wantfields true", func(t *testing.T) { - qr, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) + qr, err := wrapStreamExecute(concatenate, &noopVCursor{ctx: context.Background()}, nil, true) if tc.expectedError == "" { require.NoError(t, err) require.Equal(t, tc.expectedResult, qr) @@ -107,9 +108,8 @@ func TestConcatenate_WithSourcesErrFirst(t *testing.T) { inputs []*sqltypes.Result } - executeErr := "Concatenate.Execute: failed" - streamExecuteErr := "failed" - + strFailed := "failed" + executeErr := "Concatenate.Execute: " + strFailed testCases := []*testCase{{ testName: "empty results", inputs: []*sqltypes.Result{ @@ -146,21 +146,21 @@ func TestConcatenate_WithSourcesErrFirst(t *testing.T) { for _, tc := range testCases { t.Run(tc.testName, func(t *testing.T) { var fps []Primitive - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New("failed")}) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New(strFailed)}) for _, input := range tc.inputs { fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) } concatenate := &Concatenate{Sources: fps} t.Run("Execute wantfields true", func(t *testing.T) { - _, err := concatenate.Execute(&noopVCursor{}, nil, true) + _, err := concatenate.Execute(&noopVCursor{ctx: context.Background()}, nil, true) require.EqualError(t, err, executeErr) }) t.Run("StreamExecute wantfields true", func(t *testing.T) { - _, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) - require.EqualError(t, err, streamExecuteErr) + _, err := wrapStreamExecute(concatenate, &noopVCursor{ctx: context.Background()}, nil, true) + require.EqualError(t, err, strFailed) }) }) } @@ -171,11 +171,11 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { testName string inputs []*sqltypes.Result execErr, streamExecErr string + nonDeterministicErr bool } - executeErr := "Concatenate.Execute: failed" - streamExecuteErr := "failed" - + strFailed := "failed" + executeErr := "Concatenate.Execute: " + strFailed testCases := []*testCase{{ testName: "empty results", inputs: []*sqltypes.Result{ @@ -184,7 +184,7 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("id3|col31|col32", "int64|varbinary|varbinary")), }, execErr: executeErr, - streamExecErr: streamExecuteErr, + streamExecErr: strFailed, }, { testName: "2 non empty result", inputs: []*sqltypes.Result{ @@ -192,15 +192,16 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, execErr: executeErr, - streamExecErr: streamExecuteErr, + streamExecErr: strFailed, }, { testName: "mismatch field type", inputs: []*sqltypes.Result{ sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varbinary|varbinary"), "1|a1|b1", "2|a2|b2"), sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, - execErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", - streamExecErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + execErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + streamExecErr: "column field type does not match for name: (col1, col3) types: (VARBINARY, VARCHAR)", + nonDeterministicErr: true, }, { testName: "input source has different column count", inputs: []*sqltypes.Result{ @@ -208,7 +209,7 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), }, execErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", - streamExecErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + streamExecErr: strFailed, }, { testName: "1 empty result and 1 non empty result", inputs: []*sqltypes.Result{ @@ -216,7 +217,7 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varbinary"), "1|a1|b1", "2|a2|b2"), }, execErr: executeErr, - streamExecErr: streamExecuteErr, + streamExecErr: strFailed, }} for _, tc := range testCases { @@ -225,17 +226,20 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { for _, input := range tc.inputs { fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{input, input}}) } - fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New("failed")}) + fps = append(fps, &fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New(strFailed)}) concatenate := &Concatenate{Sources: fps} t.Run("Execute wantfields true", func(t *testing.T) { - _, err := concatenate.Execute(&noopVCursor{}, nil, true) + _, err := concatenate.Execute(&noopVCursor{ctx: context.Background()}, nil, true) require.EqualError(t, err, tc.execErr) }) t.Run("StreamExecute wantfields true", func(t *testing.T) { - _, err := wrapStreamExecute(concatenate, &noopVCursor{}, nil, true) - require.EqualError(t, err, tc.streamExecErr) + _, err := wrapStreamExecute(concatenate, &noopVCursor{ctx: context.Background()}, nil, true) + require.Error(t, err) + if !tc.nonDeterministicErr { + require.EqualError(t, err, tc.streamExecErr) + } }) }) } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 7d93c779a72..f0452eef371 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "golang.org/x/sync/errgroup" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" @@ -48,6 +50,7 @@ var _ SessionActions = (*noopVCursor)(nil) // noopVCursor is used to build other vcursors. type noopVCursor struct { + ctx context.Context } func (t noopVCursor) SetUDV(key string, value interface{}) error { @@ -61,17 +64,20 @@ func (t noopVCursor) SetSysVar(name string, expr string) { func (t noopVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error { panic("implement me") } + func (t noopVCursor) Session() SessionActions { return t } + func (t noopVCursor) SetTarget(target string) error { panic("implement me") } - func (t noopVCursor) Context() context.Context { - return context.Background() + if t.ctx == nil { + return context.Background() + } + return t.ctx } - func (t noopVCursor) MaxMemoryRows() int { return testMaxMemoryRows } @@ -80,6 +86,12 @@ func (t noopVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc return func() {} } +func (t noopVCursor) ErrorGroupCancellableContext() *errgroup.Group { + g, ctx := errgroup.WithContext(t.ctx) + t.ctx = ctx + return g +} + func (t noopVCursor) RecordWarning(warning *querypb.QueryWarning) { } @@ -169,6 +181,10 @@ func (f *loggingVCursor) SetContextTimeout(timeout time.Duration) context.Cancel return func() {} } +func (f *loggingVCursor) ErrorGroupCancellableContext() *errgroup.Group { + panic("implement me") +} + func (f *loggingVCursor) RecordWarning(warning *querypb.QueryWarning) { f.warnings = append(f.warnings, warning) } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 1fe163f2c84..8003a7ab58f 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -21,6 +21,8 @@ import ( "sync" "time" + "golang.org/x/sync/errgroup" + "vitess.io/vitess/go/vt/sqlparser" "golang.org/x/net/context" @@ -55,6 +57,9 @@ type ( // SetContextTimeout updates the context and sets a timeout. SetContextTimeout(timeout time.Duration) context.CancelFunc + // ErrorGroupCancellableContext updates context that can be cancelled. + ErrorGroupCancellableContext() *errgroup.Group + // V3 functions. Execute(method string, query string, bindvars map[string]*querypb.BindVariable, rollbackOnError bool, co vtgatepb.CommitOrder) (*sqltypes.Result, error) AutocommitApproval() bool diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index d6d8acf0483..2b739bd06d9 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -23,6 +23,8 @@ import ( "sync/atomic" "time" + "golang.org/x/sync/errgroup" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/vt/callerid" @@ -174,6 +176,13 @@ func (vc *vcursorImpl) SetContextTimeout(timeout time.Duration) context.CancelFu return cancel } +// ErrorGroupCancellableContext updates context that can be cancelled. +func (vc *vcursorImpl) ErrorGroupCancellableContext() *errgroup.Group { + g, ctx := errgroup.WithContext(vc.ctx) + vc.ctx = ctx + return g +} + // RecordWarning stores the given warning in the current session func (vc *vcursorImpl) RecordWarning(warning *querypb.QueryWarning) { vc.safeSession.RecordWarning(warning) From ef0b8872f65144de2e82ba760e2eee1f254e349d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 29 Jun 2020 10:55:18 +0200 Subject: [PATCH 11/13] union-all: Mark test as non-deterministic Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/concatenate_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index 7ad3207d4eb..e5c88f667dd 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -208,8 +208,9 @@ func TestConcatenate_WithSourcesErrLast(t *testing.T) { sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1|col2", "int64|varchar|varchar"), "1|a1|b1", "2|a2|b2"), sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col3|col4|col5", "int64|varchar|varchar|int32"), "1|a1|b1|5", "2|a2|b2|6"), }, - execErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", - streamExecErr: strFailed, + execErr: "The used SELECT statements have a different number of columns (errno 1222) (sqlstate 21000)", + streamExecErr: strFailed, + nonDeterministicErr: true, }, { testName: "1 empty result and 1 non empty result", inputs: []*sqltypes.Result{ From bd59a4834c331a0b58d1318f76e542bead47e800 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 29 Jun 2020 10:58:18 +0200 Subject: [PATCH 12/13] union-all: use specific version of x/sync Signed-off-by: Andres Taylor --- go.mod | 1 + 1 file changed, 1 insertion(+) diff --git a/go.mod b/go.mod index 0fda1348b06..dd25c892b65 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( golang.org/x/lint v0.0.0-20190409202823-959b441ac422 golang.org/x/net v0.0.0-20200202094626-16171245cfb2 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 + golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/text v0.3.2 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20191219041853-979b82bfef62 From a40e65dcd77851b3c02c0531f31a9ca85f44387b Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 29 Jun 2020 22:19:19 +0530 Subject: [PATCH 13/13] union-all: sort rows string for test comparison Signed-off-by: Harshit Gangal --- go/test/endtoend/vtgate/misc_test.go | 14 ++++---------- go/test/utils/sort.go | 13 +++++++++++++ go/vt/vtgate/engine/concatenate_test.go | 5 ++++- 3 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 go/test/utils/sort.go diff --git a/go/test/endtoend/vtgate/misc_test.go b/go/test/endtoend/vtgate/misc_test.go index dba04c1cb01..55cf98328ae 100644 --- a/go/test/endtoend/vtgate/misc_test.go +++ b/go/test/endtoend/vtgate/misc_test.go @@ -19,10 +19,10 @@ package vtgate import ( "context" "fmt" - "sort" - "strings" "testing" + "vitess.io/vitess/go/test/utils" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" @@ -586,20 +586,14 @@ func TestUnionAll(t *testing.T) { // union all between two selectuniquein tables qr := exec(t, conn, "select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8) union all select id1 from t1 where id1 in (1, 2, 3, 4, 5, 6, 7, 8)") - expected := sortString("[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)] [INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)]]") - assert.Equal(t, expected, sortString(fmt.Sprintf("%v", qr.Rows))) + expected := utils.SortString("[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)] [INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)]]") + assert.Equal(t, expected, utils.SortString(fmt.Sprintf("%v", qr.Rows))) // clean up exec(t, conn, "delete from t1") exec(t, conn, "delete from t2") } -func sortString(w string) string { - s := strings.Split(w, "") - sort.Strings(s) - return strings.Join(s, "") -} - func TestUnion(t *testing.T) { conn, err := mysql.Connect(context.Background(), &vtParams) require.NoError(t, err) diff --git a/go/test/utils/sort.go b/go/test/utils/sort.go new file mode 100644 index 00000000000..f3584ef70ac --- /dev/null +++ b/go/test/utils/sort.go @@ -0,0 +1,13 @@ +package utils + +import ( + "sort" + "strings" +) + +//SortString sorts the string. +func SortString(w string) string { + s := strings.Split(w, "") + sort.Strings(s) + return strings.Join(s, "") +} diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index e5c88f667dd..27cda59ce40 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -19,8 +19,11 @@ package engine import ( "context" "errors" + "fmt" "testing" + "vitess.io/vitess/go/test/utils" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" ) @@ -93,7 +96,7 @@ func TestConcatenate_NoSourcesErr(t *testing.T) { qr, err := wrapStreamExecute(concatenate, &noopVCursor{ctx: context.Background()}, nil, true) if tc.expectedError == "" { require.NoError(t, err) - require.Equal(t, tc.expectedResult, qr) + require.Equal(t, utils.SortString(fmt.Sprintf("%v", tc.expectedResult.Rows)), utils.SortString(fmt.Sprintf("%v", qr.Rows))) } else { require.EqualError(t, err, tc.expectedError) }