From 0ba15caf20945553816146a0655da9a9861db866 Mon Sep 17 00:00:00 2001 From: Shengzhe Yao Date: Tue, 21 Jul 2015 18:07:02 -0700 Subject: [PATCH] remove panic out of queryservice 1. Remove panics out of QueryExecutor and return error instead. 2. Some places like TxPool, CachePool, SchemaInfo still panics and will be captured and handled in SqlQuery. --- go/vt/tabletserver/codex.go | 29 +- go/vt/tabletserver/codex_test.go | 79 +-- go/vt/tabletserver/dbconn.go | 3 +- go/vt/tabletserver/query_executor.go | 429 ++++++++++------ go/vt/tabletserver/query_executor_test.go | 582 ++++++++++++++++------ go/vt/tabletserver/sqlquery.go | 71 ++- 6 files changed, 807 insertions(+), 386 deletions(-) diff --git a/go/vt/tabletserver/codex.go b/go/vt/tabletserver/codex.go index 75d5bd406fd..64a52b3649f 100644 --- a/go/vt/tabletserver/codex.go +++ b/go/vt/tabletserver/codex.go @@ -40,12 +40,13 @@ func buildValueList(tableInfo *TableInfo, pkValues []interface{}, bindVars map[s func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) (resolved []interface{}, length int, err error) { length = -1 - setLength := func(list []sqltypes.Value) { + setLengthFunc := func(list []sqltypes.Value) error { if length == -1 { length = len(list) } else if len(list) != length { - panic(NewTabletError(ErrFail, "mismatched lengths for values %v", pkValues)) + return NewTabletError(ErrFail, "mismatched lengths for values %v", pkValues) } + return nil } resolved = make([]interface{}, len(pkValues)) for i, val := range pkValues { @@ -61,7 +62,9 @@ func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[ if err != nil { return nil, 0, err } - setLength(list) + if err := setLengthFunc(list); err != nil { + return nil, 0, err + } resolved[i] = list } case []interface{}: @@ -72,7 +75,9 @@ func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[ return nil, 0, err } } - setLength(list) + if err := setLengthFunc(list); err != nil { + return nil, 0, err + } resolved[i] = list default: resolved[i], err = resolveValue(tableInfo.GetPKColumn(i), val, nil) @@ -144,7 +149,7 @@ func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[strin case sqltypes.Value: result = v default: - panic(NewTabletError(ErrFail, "incompatible value type %v", v)) + return result, NewTabletError(ErrFail, "incompatible value type %v", v) } if err = validateValue(col, result); err != nil { @@ -185,12 +190,12 @@ func validateValue(col *schema.TableColumn, value sqltypes.Value) error { // getLimit resolves the rowcount or offset of the limit clause value. // It returns -1 if it's not set. -func getLimit(limit interface{}, bv map[string]interface{}) int64 { +func getLimit(limit interface{}, bv map[string]interface{}) (int64, error) { switch lim := limit.(type) { case string: lookup, ok := bv[lim[1:]] if !ok { - panic(NewTabletError(ErrFail, "missing bind var %s", lim)) + return -1, NewTabletError(ErrFail, "missing bind var %s", lim) } var newlim int64 switch l := lookup.(type) { @@ -201,16 +206,16 @@ func getLimit(limit interface{}, bv map[string]interface{}) int64 { case int: newlim = int64(l) default: - panic(NewTabletError(ErrFail, "want number type for %s, got %T", lim, lookup)) + return -1, NewTabletError(ErrFail, "want number type for %s, got %T", lim, lookup) } if newlim < 0 { - panic(NewTabletError(ErrFail, "negative limit %d", newlim)) + return -1, NewTabletError(ErrFail, "negative limit %d", newlim) } - return newlim + return newlim, nil case int64: - return lim + return lim, nil default: - return -1 + return -1, nil } } diff --git a/go/vt/tabletserver/codex_test.go b/go/vt/tabletserver/codex_test.go index c956eaae35f..3779ad212fd 100644 --- a/go/vt/tabletserver/codex_test.go +++ b/go/vt/tabletserver/codex_test.go @@ -229,10 +229,8 @@ func TestCodexResolvePKValues(t *testing.T) { pkValues = make([]interface{}, 0, 10) pkValues = append(pkValues, []interface{}{":" + key}) pkValues = append(pkValues, []interface{}{":" + key2, ":" + key3}) - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "mismatched lengths") - _, _, err = resolvePKValues(&tableInfo, pkValues, bindVariables) - }() + _, _, err = resolvePKValues(&tableInfo, pkValues, bindVariables) + testUtils.checkTabletError(t, err, ErrFail, "mismatched lengths") } func TestCodexResolveListArg(t *testing.T) { @@ -322,11 +320,8 @@ func TestCodexResolveValueWithIncompatibleValueType(t *testing.T) { []string{"pk1", "pk2", "col1"}, []string{"int", "varbinary(128)", "int"}, []string{"pk1", "pk2"}) - - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "incompatible value type ") - resolveValue(tableInfo.GetPKColumn(0), 0, nil) - }() + _, err := resolveValue(tableInfo.GetPKColumn(0), 0, nil) + testUtils.checkTabletError(t, err, ErrFail, "incompatible value type ") } func TestCodexValidateRow(t *testing.T) { @@ -352,46 +347,52 @@ func TestCodexGetLimit(t *testing.T) { "uint": uint(1), } testUtils := newTestUtils() - // to handle panics - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "missing bind var") - getLimit(":unknown", bv) - }() - if result := getLimit(int64(1), bv); result != 1 { + _, err := getLimit(":unknown", bv) + if err == nil { + t.Fatal("got nil, want error: missing bind var") + } + testUtils.checkTabletError(t, err, ErrFail, "missing bind var") + result, err := getLimit(int64(1), bv) + if err != nil { + t.Fatalf("getLimit(1, bv) = %v, want nil", err) + } + if result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(nil, bv); result != -1 { + result, err = getLimit(nil, bv) + if err != nil { + t.Fatalf("getLimit(nil, bv) = %v, want nil", err) + } + if result != -1 { t.Fatalf("got %d, want -1", result) } - func() { - defer func() { - x := recover().(error).Error() - want := "error: negative limit -1" - if x != want { - t.Fatalf("got %s, want %s", x, want) - } - }() - getLimit(":negative", bv) - }() - if result := getLimit(":int64", bv); result != 1 { + + result, err = getLimit(":negative", bv) + if err == nil { + t.Fatalf("getLimit(':negative', bv) should return an error") + } + want := "error: negative limit -1" + if err.Error() != want { + t.Fatalf("got %s, want %s", err.Error(), want) + } + if result, _ := getLimit(":int64", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(":int32", bv); result != 1 { + if result, _ := getLimit(":int32", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(":int", bv); result != 1 { + if result, _ := getLimit(":int", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - func() { - defer func() { - x := recover().(error).Error() - want := "error: want number type for :uint, got uint" - if x != want { - t.Fatalf("got %s, want %s", x, want) - } - }() - getLimit(":uint", bv) - }() + + _, err = getLimit(":uint", bv) + if err == nil { + t.Fatalf("getLimit(':uint', bv) should return an error") + } + want = "error: want number type for :uint, got uint" + if err.Error() != want { + t.Fatalf("got %s, want %s", err.Error(), want) + } } func TestCodexBuildKey(t *testing.T) { diff --git a/go/vt/tabletserver/dbconn.go b/go/vt/tabletserver/dbconn.go index 21f0bb1aa99..8800acd6fdd 100644 --- a/go/vt/tabletserver/dbconn.go +++ b/go/vt/tabletserver/dbconn.go @@ -5,6 +5,7 @@ package tabletserver import ( + "errors" "fmt" "time" @@ -67,7 +68,7 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel return nil, NewTabletErrorSql(ErrFatal, err) } } - panic("unreachable") + return nil, NewTabletErrorSql(ErrFatal, errors.New("dbconn.Exec: unreachable code")) } func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*mproto.QueryResult, error) { diff --git a/go/vt/tabletserver/query_executor.go b/go/vt/tabletserver/query_executor.go index 73b1ed46c82..a1eda00ea59 100644 --- a/go/vt/tabletserver/query_executor.go +++ b/go/vt/tabletserver/query_executor.go @@ -36,7 +36,7 @@ type poolConn interface { } // Execute performs a non-streaming query execution. -func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { +func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult, err error) { qre.logStats.OriginalSql = qre.query qre.logStats.BindVariables = qre.bindVars qre.logStats.TransactionID = qre.transactionID @@ -55,7 +55,9 @@ func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { qre.qe.queryServiceStats.ResultStats.Add(int64(len(reply.Rows))) }(time.Now()) - qre.checkPermissions() + if err := qre.checkPermissions(); err != nil { + return nil, err + } if qre.plan.PlanId == planbuilder.PLAN_DDL { return qre.execDDL() @@ -73,79 +75,89 @@ func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_DML: if qre.qe.strictMode.Get() != 0 { - panic(NewTabletError(ErrFail, "DML too complex")) + return nil, NewTabletError(ErrFail, "DML too complex") } - reply = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + reply, err = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) case planbuilder.PLAN_INSERT_PK: - reply = qre.execInsertPK(conn) + reply, err = qre.execInsertPK(conn) case planbuilder.PLAN_INSERT_SUBQUERY: - reply = qre.execInsertSubquery(conn) + reply, err = qre.execInsertSubquery(conn) case planbuilder.PLAN_DML_PK: - reply = qre.execDMLPK(conn, invalidator) + reply, err = qre.execDMLPK(conn, invalidator) case planbuilder.PLAN_DML_SUBQUERY: - reply = qre.execDMLSubquery(conn, invalidator) + reply, err = qre.execDMLSubquery(conn, invalidator) case planbuilder.PLAN_OTHER: - reply = qre.execSQL(conn, qre.query, true) + reply, err = qre.execSQL(conn, qre.query, true) default: // select or set in a transaction, just count as select - reply = qre.execDirect(conn) + reply, err = qre.execDirect(conn) } } else { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_SELECT: if qre.plan.Reason == planbuilder.REASON_LOCK { - panic(NewTabletError(ErrFail, "Disallowed outside transaction")) + return nil, NewTabletError(ErrFail, "Disallowed outside transaction") } - reply = qre.execSelect() + reply, err = qre.execSelect() case planbuilder.PLAN_PK_IN: - reply = qre.execPKIN() + reply, err = qre.execPKIN() case planbuilder.PLAN_SELECT_SUBQUERY: - reply = qre.execSubquery() + reply, err = qre.execSubquery() case planbuilder.PLAN_SET: - reply = qre.execSet() + reply, err = qre.execSet() case planbuilder.PLAN_OTHER: - conn := qre.getConn(qre.qe.connPool) + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err + } defer conn.Recycle() - reply = qre.execSQL(conn, qre.query, true) + reply, err = qre.execSQL(conn, qre.query, true) default: if !qre.qe.enableAutoCommit { - panic(NewTabletError(ErrFatal, "unsupported query: %s", qre.query)) + return nil, NewTabletError(ErrFatal, "unsupported query: %s", qre.query) } - reply = qre.execDmlAutoCommit() + reply, err = qre.execDmlAutoCommit() } } - return reply + return reply, err } // Stream performs a streaming query execution. -func (qre *QueryExecutor) Stream(sendReply func(*mproto.QueryResult) error) { +func (qre *QueryExecutor) Stream(sendReply func(*mproto.QueryResult) error) error { qre.logStats.OriginalSql = qre.query qre.logStats.PlanType = qre.plan.PlanId.String() defer qre.qe.queryServiceStats.QueryStats.Record(qre.plan.PlanId.String(), time.Now()) - qre.checkPermissions() + if err := qre.checkPermissions(); err != nil { + return err + } - conn := qre.getConn(qre.qe.streamConnPool) + conn, err := qre.getConn(qre.qe.streamConnPool) + if err != nil { + return err + } defer conn.Recycle() qd := NewQueryDetail(qre.logStats.ctx, conn) qre.qe.streamQList.Add(qd) defer qre.qe.streamQList.Remove(qd) - qre.fullStreamFetch(conn, qre.plan.FullQuery, qre.bindVars, nil, sendReply) + return qre.fullStreamFetch(conn, qre.plan.FullQuery, qre.bindVars, nil, sendReply) } -func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult) { +func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult, err error) { transactionID := qre.qe.txPool.Begin(qre.ctx) qre.logStats.AddRewrittenSql("begin", time.Now()) defer func() { - err := recover() - if err == nil { - qre.qe.Commit(qre.ctx, qre.logStats, transactionID) - qre.logStats.AddRewrittenSql("commit", time.Now()) - } else { + // TxPool.Get may panic + if panicErr := recover(); panicErr != nil { + err = fmt.Errorf("DML autocommit got panic: %v", panicErr) + } + if err != nil { qre.qe.txPool.Rollback(qre.ctx, transactionID) qre.logStats.AddRewrittenSql("rollback", time.Now()) - panic(err) + } else { + qre.qe.Commit(qre.ctx, qre.logStats, transactionID) + qre.logStats.AddRewrittenSql("commit", time.Now()) } }() conn := qre.qe.txPool.Get(transactionID) @@ -157,27 +169,27 @@ func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult) { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_DML: if qre.qe.strictMode.Get() != 0 { - panic(NewTabletError(ErrFail, "DML too complex")) + return nil, NewTabletError(ErrFail, "DML too complex") } - reply = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + reply, err = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) case planbuilder.PLAN_INSERT_PK: - reply = qre.execInsertPK(conn) + reply, err = qre.execInsertPK(conn) case planbuilder.PLAN_INSERT_SUBQUERY: - reply = qre.execInsertSubquery(conn) + reply, err = qre.execInsertSubquery(conn) case planbuilder.PLAN_DML_PK: - reply = qre.execDMLPK(conn, invalidator) + reply, err = qre.execDMLPK(conn, invalidator) case planbuilder.PLAN_DML_SUBQUERY: - reply = qre.execDMLSubquery(conn, invalidator) + reply, err = qre.execDMLSubquery(conn, invalidator) default: - panic(NewTabletError(ErrFatal, "unsupported query: %s", qre.query)) + return nil, NewTabletError(ErrFatal, "unsupported query: %s", qre.query) } - return reply + return reply, nil } -func (qre *QueryExecutor) checkPermissions() { +func (qre *QueryExecutor) checkPermissions() error { // Skip permissions check if we have a background context. if qre.ctx == context.Background() { - return + return nil } // Blacklist @@ -191,9 +203,9 @@ func (qre *QueryExecutor) checkPermissions() { action, desc := qre.plan.Rules.getAction(remoteAddr, username, qre.bindVars) switch action { case QR_FAIL: - panic(NewTabletError(ErrFail, "Query disallowed due to rule: %s", desc)) + return NewTabletError(ErrFail, "Query disallowed due to rule: %s", desc) case QR_FAIL_RETRY: - panic(NewTabletError(ErrRetry, "Query disallowed due to rule: %s", desc)) + return NewTabletError(ErrRetry, "Query disallowed due to rule: %s", desc) } // Perform table ACL check if it is enabled @@ -201,16 +213,17 @@ func (qre *QueryExecutor) checkPermissions() { errStr := fmt.Sprintf("table acl error: %q cannot run %v on table %q", username, qre.plan.PlanId, qre.plan.TableName) // Raise error if in strictTableAcl mode, else just log an error if qre.qe.strictTableAcl { - panic(NewTabletError(ErrFail, "%s", errStr)) + return NewTabletError(ErrFail, "%s", errStr) } qre.qe.accessCheckerLogger.Errorf("%s", errStr) } + return nil } -func (qre *QueryExecutor) execDDL() *mproto.QueryResult { +func (qre *QueryExecutor) execDDL() (*mproto.QueryResult, error) { ddlPlan := planbuilder.DDLParse(qre.query) if ddlPlan.Action == "" { - panic(NewTabletError(ErrFail, "DDL is not understood")) + return nil, NewTabletError(ErrFail, "DDL is not understood") } txid := qre.qe.txPool.Begin(qre.ctx) @@ -219,8 +232,10 @@ func (qre *QueryExecutor) execDDL() *mproto.QueryResult { // Stolen from Execute conn := qre.qe.txPool.Get(txid) defer conn.Recycle() - result := qre.execSQL(conn, qre.query, false) - + result, err := qre.execSQL(conn, qre.query, false) + if err != nil { + return nil, err + } if ddlPlan.TableName != "" && ddlPlan.TableName != ddlPlan.NewName { // It's a drop or rename. qre.qe.schemaInfo.DropTable(ddlPlan.TableName) @@ -228,31 +243,37 @@ func (qre *QueryExecutor) execDDL() *mproto.QueryResult { if ddlPlan.NewName != "" { qre.qe.schemaInfo.CreateOrUpdateTable(qre.ctx, ddlPlan.NewName) } - return result + return result, nil } -func (qre *QueryExecutor) execPKIN() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execPKIN() (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } - return qre.fetchMulti(pkRows, getLimit(qre.plan.Limit, qre.bindVars)) + limit, err := getLimit(qre.plan.Limit, qre.bindVars) + if err != nil { + return nil, err + } + return qre.fetchMulti(pkRows, limit) } -func (qre *QueryExecutor) execSubquery() (result *mproto.QueryResult) { - innerResult := qre.qFetch(qre.logStats, qre.plan.Subquery, qre.bindVars) +func (qre *QueryExecutor) execSubquery() (*mproto.QueryResult, error) { + innerResult, err := qre.qFetch(qre.logStats, qre.plan.Subquery, qre.bindVars) + if err != nil { + return nil, err + } return qre.fetchMulti(innerResult.Rows, -1) } -func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (result *mproto.QueryResult) { +func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (*mproto.QueryResult, error) { if qre.plan.Fields == nil { - panic("unexpected") + return nil, NewTabletError(ErrFatal, "query plan.Fields is empty") } - result = &mproto.QueryResult{Fields: qre.plan.Fields} + result := &mproto.QueryResult{Fields: qre.plan.Fields} if len(pkRows) == 0 || limit == 0 { - return + return result, nil } - tableInfo := qre.plan.TableInfo keys := make([]string, len(pkRows)) for i, pk := range pkRows { @@ -266,7 +287,10 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re rcresult := rcresults[keys[i]] if rcresult.Row != nil { if qre.mustVerify() { - qre.spotCheck(rcresult, pk) + err := qre.spotCheck(rcresult, pk) + if err != nil { + return nil, err + } } rows = append(rows, applyFilter(qre.plan.ColumnNumbers, rcresult.Row)) hits++ @@ -281,7 +305,10 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re Rows: missingRows, }, } - resultFromdb := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + resultFromdb, err := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + if err != nil { + return nil, err + } misses = int64(len(resultFromdb.Rows)) absent = int64(len(pkRows)) - hits - misses for _, row := range resultFromdb.Rows { @@ -307,14 +334,14 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re result.Rows = result.Rows[:limit] result.RowsAffected = uint64(limit) } - return result + return result, nil } func (qre *QueryExecutor) mustVerify() bool { return (Rand() % spotCheckMultiplier) < qre.qe.spotCheckFreq.Get() } -func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { +func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) error { qre.qe.queryServiceStats.SpotCheckCount.Add(1) bv := map[string]interface{}{ "#pk": sqlparser.TupleEqualityList{ @@ -322,7 +349,10 @@ func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { Rows: [][]sqltypes.Value{pk}, }, } - resultFromdb := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + resultFromdb, err := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + if err != nil { + return err + } var dbrow []sqltypes.Value if len(resultFromdb.Rows) != 0 { dbrow = resultFromdb.Rows[0] @@ -330,6 +360,7 @@ func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { if dbrow == nil || !rowsAreEqual(rcresult.Row, dbrow) { qre.qe.Launch(func() { qre.recheckLater(rcresult, dbrow, pk) }) } + return nil } func (qre *QueryExecutor) recheckLater(rcresult RCResult, dbrow []sqltypes.Value, pk []sqltypes.Value) { @@ -347,45 +378,56 @@ func (qre *QueryExecutor) recheckLater(rcresult RCResult, dbrow []sqltypes.Value } // execDirect always sends the query to mysql -func (qre *QueryExecutor) execDirect(conn poolConn) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDirect(conn poolConn) (*mproto.QueryResult, error) { if qre.plan.Fields != nil { - result = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + result, err := qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + if err != nil { + return nil, err + } result.Fields = qre.plan.Fields - return + return result, err } - result = qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) - return + return qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } // execSelect sends a query to mysql only if another identical query is not running. Otherwise, it waits and // reuses the result. If the plan is missng field info, it sends the query to mysql requesting full info. -func (qre *QueryExecutor) execSelect() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execSelect() (*mproto.QueryResult, error) { if qre.plan.Fields != nil { - result = qre.qFetch(qre.logStats, qre.plan.FullQuery, qre.bindVars) + result, err := qre.qFetch(qre.logStats, qre.plan.FullQuery, qre.bindVars) + if err != nil { + return nil, err + } result.Fields = qre.plan.Fields - return + return result, nil + } + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err } - conn := qre.getConn(qre.qe.connPool) defer conn.Recycle() return qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } -func (qre *QueryExecutor) execInsertPK(conn poolConn) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execInsertPK(conn poolConn) (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } return qre.execInsertPKRows(conn, pkRows) } -func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (result *mproto.QueryResult) { - innerResult := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) +func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (*mproto.QueryResult, error) { + innerResult, err := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) + if err != nil { + return nil, err + } innerRows := innerResult.Rows if len(innerRows) == 0 { - return &mproto.QueryResult{RowsAffected: 0} + return &mproto.QueryResult{RowsAffected: 0}, nil } if len(qre.plan.ColumnNumbers) != len(innerRows[0]) { - panic(NewTabletError(ErrFail, "Subquery length does not match column list")) + return nil, NewTabletError(ErrFail, "Subquery length does not match column list") } pkRows := make([][]sqltypes.Value, len(innerRows)) for i, innerRow := range innerRows { @@ -393,46 +435,48 @@ func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (result *mproto.Quer } // Validating first row is sufficient if err := validateRow(qre.plan.TableInfo, qre.plan.TableInfo.PKColumns, pkRows[0]); err != nil { - panic(err) + return nil, err } qre.bindVars["#values"] = innerRows return qre.execInsertPKRows(conn, pkRows) } -func (qre *QueryExecutor) execInsertPKRows(conn poolConn, pkRows [][]sqltypes.Value) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execInsertPKRows(conn poolConn, pkRows [][]sqltypes.Value) (*mproto.QueryResult, error) { secondaryList, err := buildSecondaryList(qre.plan.TableInfo, pkRows, qre.plan.SecondaryPKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } bsc := buildStreamComment(qre.plan.TableInfo, pkRows, secondaryList) - result = qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) - return result + return qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) } -func (qre *QueryExecutor) execDMLPK(conn poolConn, invalidator CacheInvalidator) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDMLPK(conn poolConn, invalidator CacheInvalidator) (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } return qre.execDMLPKRows(conn, pkRows, invalidator) } -func (qre *QueryExecutor) execDMLSubquery(conn poolConn, invalidator CacheInvalidator) (result *mproto.QueryResult) { - innerResult := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) +func (qre *QueryExecutor) execDMLSubquery(conn poolConn, invalidator CacheInvalidator) (*mproto.QueryResult, error) { + innerResult, err := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) + if err != nil { + return nil, err + } return qre.execDMLPKRows(conn, innerResult.Rows, invalidator) } -func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value, invalidator CacheInvalidator) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value, invalidator CacheInvalidator) (*mproto.QueryResult, error) { if len(pkRows) == 0 { - return &mproto.QueryResult{RowsAffected: 0} + return &mproto.QueryResult{RowsAffected: 0}, nil } secondaryList, err := buildSecondaryList(qre.plan.TableInfo, pkRows, qre.plan.SecondaryPKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } - result = &mproto.QueryResult{} + result := &mproto.QueryResult{} maxRows := int(qre.qe.maxDMLRows.Get()) for i := 0; i < len(pkRows); i += maxRows { end := i + maxRows @@ -449,93 +493,157 @@ func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value Columns: qre.plan.TableInfo.Indexes[0].Columns, Rows: pkRows, } - r := qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) + r, err := qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) + if err != nil { + return nil, err + } // DMLs should only return RowsAffected. result.RowsAffected += r.RowsAffected } if invalidator == nil { - return result + return result, nil } for _, pk := range pkRows { key := buildKey(pk) invalidator.Delete(key) } - return result + return result, nil } -func (qre *QueryExecutor) execSet() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execSet() (*mproto.QueryResult, error) { switch qre.plan.SetKey { case "vt_pool_size": - qre.qe.connPool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_pool_size = %v, want to int64", err) + } + qre.qe.connPool.SetCapacity(int(val)) case "vt_stream_pool_size": - qre.qe.streamConnPool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_stream_pool_size = %v, want int64", err) + } + qre.qe.streamConnPool.SetCapacity(int(val)) case "vt_transaction_cap": - qre.qe.txPool.pool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_transaction_cap = %v, want int64", err) + } + qre.qe.txPool.pool.SetCapacity(int(val)) case "vt_transaction_timeout": - qre.qe.txPool.SetTimeout(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_transaction_timeout = %v, want int64 or float64", err) + } + qre.qe.txPool.SetTimeout(val) case "vt_schema_reload_time": - qre.qe.schemaInfo.SetReloadTime(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_schema_reload_time = %v, want int64 or float64", err) + } + qre.qe.schemaInfo.SetReloadTime(val) case "vt_query_cache_size": - qre.qe.schemaInfo.SetQueryCacheSize(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_query_cache_size = %v, want int64", err) + } + qre.qe.schemaInfo.SetQueryCacheSize(int(val)) case "vt_max_result_size": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_max_result_size = %v, want int64", err) + } if val < 1 { - panic(NewTabletError(ErrFail, "vt_max_result_size out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_max_result_size out of range %v", val) } qre.qe.maxResultSize.Set(val) case "vt_max_dml_rows": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_max_dml_rows = %v, want to int64", err) + } if val < 1 { - panic(NewTabletError(ErrFail, "vt_max_dml_rows out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_max_dml_rows out of range %v", val) } qre.qe.maxDMLRows.Set(val) case "vt_stream_buffer_size": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_stream_buffer_size = %v, want int64", err) + } + if val < 1024 { - panic(NewTabletError(ErrFail, "vt_stream_buffer_size out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_stream_buffer_size out of range %v", val) } qre.qe.streamBufferSize.Set(val) case "vt_query_timeout": - qre.qe.queryTimeout.Set(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_query_timeout = %v, want int64 or float64", err) + } + qre.qe.queryTimeout.Set(val) case "vt_idle_timeout": - t := getDuration(qre.plan.SetValue) - qre.qe.connPool.SetIdleTimeout(t) - qre.qe.streamConnPool.SetIdleTimeout(t) - qre.qe.txPool.pool.SetIdleTimeout(t) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_idle_timeout = %v, want int64 or float64", err) + } + qre.qe.connPool.SetIdleTimeout(val) + qre.qe.streamConnPool.SetIdleTimeout(val) + qre.qe.txPool.pool.SetIdleTimeout(val) case "vt_spot_check_ratio": - qre.qe.spotCheckFreq.Set(int64(getFloat64(qre.plan.SetValue) * spotCheckMultiplier)) + val, err := parseFloat64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_spot_check_ratio = %v, want float64", err) + } + qre.qe.spotCheckFreq.Set(int64(val * spotCheckMultiplier)) case "vt_strict_mode": - qre.qe.strictMode.Set(getInt64(qre.plan.SetValue)) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_strict_mode = %v, want to int64", err) + } + qre.qe.strictMode.Set(val) case "vt_txpool_timeout": - t := getDuration(qre.plan.SetValue) - qre.qe.txPool.SetPoolTimeout(t) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_txpool_timeout = %v, want int64 or float64", err) + } + qre.qe.txPool.SetPoolTimeout(val) default: - conn := qre.getConn(qre.qe.connPool) + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err + } defer conn.Recycle() return qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } - return &mproto.QueryResult{} + return &mproto.QueryResult{}, nil } -func getInt64(v interface{}) int64 { +func parseInt64(v interface{}) (int64, error) { if ival, ok := v.(int64); ok { - return ival + return ival, nil } - panic(NewTabletError(ErrFail, "expecting int")) + return -1, NewTabletError(ErrFail, "got %v, want int64", v) } -func getFloat64(v interface{}) float64 { +func parseFloat64(v interface{}) (float64, error) { if ival, ok := v.(int64); ok { - return float64(ival) + return float64(ival), nil } if fval, ok := v.(float64); ok { - return fval + return fval, nil } - panic(NewTabletError(ErrFail, "expecting number")) + return -1, NewTabletError(ErrFail, "got %v, want int64 or float64", v) } -func getDuration(v interface{}) time.Duration { - return time.Duration(getFloat64(v) * 1e9) +func parseDuration(v interface{}) (time.Duration, error) { + val, err := parseFloat64(v) + if err != nil { + return 0, err + } + // time.Duration is an int64, have to multiple by 1e9 because + // val might be in range (0, 1) + return time.Duration(val * 1e9), nil } func rowsAreEqual(row1, row2 []sqltypes.Value) bool { @@ -553,21 +661,24 @@ func rowsAreEqual(row1, row2 []sqltypes.Value) bool { return true } -func (qre *QueryExecutor) getConn(pool *ConnPool) *DBConn { +func (qre *QueryExecutor) getConn(pool *ConnPool) (*DBConn, error) { start := time.Now() conn, err := pool.Get(qre.ctx) switch err { case nil: qre.logStats.WaitingForConnection += time.Now().Sub(start) - return conn + return conn, nil case ErrConnPoolClosed: - panic(err) + return nil, err } - panic(NewTabletErrorSql(ErrFatal, err)) + return nil, NewTabletErrorSql(ErrFatal, err) } -func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, nil) +func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, nil) + if err != nil { + return nil, err + } q, ok := qre.qe.consolidator.Create(string(sql)) if ok { defer q.Broadcast() @@ -578,7 +689,7 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser q.Err = NewTabletErrorSql(ErrFatal, err) } else { defer conn.Recycle() - q.Result, q.Err = qre.execSQLNoPanic(conn, sql, false) + q.Result, q.Err = qre.execSQL(conn, sql, false) } } else { logStats.QuerySources |= QuerySourceConsolidator @@ -587,59 +698,61 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser qre.qe.queryServiceStats.WaitStats.Record("Consolidations", startTime) } if q.Err != nil { - panic(q.Err) + return nil, q.Err } - return q.Result.(*mproto.QueryResult) + return q.Result.(*mproto.QueryResult), nil } -func (qre *QueryExecutor) directFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) +func (qre *QueryExecutor) directFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return nil, err + } return qre.execSQL(conn, sql, false) } // fullFetch also fetches field info -func (qre *QueryExecutor) fullFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) +func (qre *QueryExecutor) fullFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return nil, err + } return qre.execSQL(conn, sql, true) } -func (qre *QueryExecutor) fullStreamFetch(conn *DBConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte, callback func(*mproto.QueryResult) error) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) - qre.execStreamSQL(conn, sql, callback) +func (qre *QueryExecutor) fullStreamFetch(conn *DBConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte, callback func(*mproto.QueryResult) error) error { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return err + } + return qre.execStreamSQL(conn, sql, callback) } -func (qre *QueryExecutor) generateFinalSql(parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) string { +func (qre *QueryExecutor) generateFinalSQL(parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (string, error) { bindVars["#maxLimit"] = qre.qe.maxResultSize.Get() + 1 sql, err := parsedQuery.GenerateQuery(bindVars) if err != nil { - panic(NewTabletError(ErrFail, "%s", err)) + return "", NewTabletError(ErrFail, "%s", err) } if buildStreamComment != nil { sql = append(sql, buildStreamComment...) } // undo hack done by stripTrailing sql = restoreTrailing(sql, bindVars) - return hack.String(sql) -} - -func (qre *QueryExecutor) execSQL(conn poolConn, sql string, wantfields bool) *mproto.QueryResult { - result, err := qre.execSQLNoPanic(conn, sql, wantfields) - if err != nil { - panic(err) - } - return result + return hack.String(sql), nil } -func (qre *QueryExecutor) execSQLNoPanic(conn poolConn, sql string, wantfields bool) (*mproto.QueryResult, error) { +func (qre *QueryExecutor) execSQL(conn poolConn, sql string, wantfields bool) (*mproto.QueryResult, error) { defer qre.logStats.AddRewrittenSql(sql, time.Now()) return conn.Exec(qre.ctx, sql, int(qre.qe.maxResultSize.Get()), wantfields) } -func (qre *QueryExecutor) execStreamSQL(conn *DBConn, sql string, callback func(*mproto.QueryResult) error) { +func (qre *QueryExecutor) execStreamSQL(conn *DBConn, sql string, callback func(*mproto.QueryResult) error) error { start := time.Now() err := conn.Stream(qre.ctx, sql, callback, int(qre.qe.streamBufferSize.Get())) qre.logStats.AddRewrittenSql(sql, start) if err != nil { - panic(NewTabletErrorSql(ErrFail, err)) + return NewTabletErrorSql(ErrFail, err) } + return nil } diff --git a/go/vt/tabletserver/query_executor_test.go b/go/vt/tabletserver/query_executor_test.go index abc7c5b3b47..f091bff0b47 100644 --- a/go/vt/tabletserver/query_executor_test.go +++ b/go/vt/tabletserver/query_executor_test.go @@ -26,31 +26,41 @@ import ( func TestQueryExecutorPlanDDL(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "alter table test_table add zipcode int" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DDL, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassDmlStrictMode(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set pk = foo()" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) // non strict mode qre, sqlQuery := newTestQueryExecutor(query, context.Background(), enableTx) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } testCommitHelper(t, sqlQuery, qre) sqlQuery.disallowQueries() @@ -62,39 +72,57 @@ func TestQueryExecutorPlanPassDmlStrictMode(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - defer handleAndVerifyTabletError( - t, - "update should fail because strict mode is enabled", - ErrFail) - qre.Execute() + got, err = qre.Execute() + if err == nil { + t.Fatal("qre.Execute() = nil, want error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: a TabletError", tabletError) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(ErrFail)) + } } func TestQueryExecutorPlanPassDmlStrictModeAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set pk = foo()" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) // non strict mode qre, sqlQuery := newTestQueryExecutor(query, context.Background(), noFlags) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } sqlQuery.disallowQueries() // strict mode + // update should fail because strict mode is not enabled qre, sqlQuery = newTestQueryExecutor( "update test_table set pk = foo()", context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - defer handleAndVerifyTabletError( - t, - "update should fail because strict mode is not enabled", - ErrFail) - qre.Execute() + _, err = qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", tabletError) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(ErrFail)) + } } func TestQueryExecutorPlanInsertPk(t *testing.T) { @@ -110,20 +138,22 @@ func TestQueryExecutorPlanInsertPk(t *testing.T) { enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_INSERT_PK, qre.plan.PlanId) - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query: %s, QueryExecutor.Execute() = %v, want: %v", sql, got, want) + t.Fatalf("got: %v, want: %v", got, want) } } func TestQueryExecutorPlanInsertSubQueryAutoCommmit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "insert into test_table(pk) select pk from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) selectQuery := "select pk from test_table where pk = 1 limit 1000" db.AddQuery(selectQuery, &mproto.QueryResult{ RowsAffected: 1, @@ -140,17 +170,22 @@ func TestQueryExecutorPlanInsertSubQueryAutoCommmit(t *testing.T) { query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_INSERT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanInsertSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "insert into test_table(pk) select pk from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) selectQuery := "select pk from test_table where pk = 1 limit 1000" db.AddQuery(selectQuery, &mproto.QueryResult{ RowsAffected: 1, @@ -168,107 +203,130 @@ func TestQueryExecutorPlanInsertSubQuery(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_INSERT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlPk(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set name = 2 where pk in (1) /* _stream test_table (pk ) (1 ); */" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - + want := &mproto.QueryResult{} + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableTx|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_DML_PK, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set name = 2 where pk in (1) /* _stream test_table (pk ) (1 ); */" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - + want := &mproto.QueryResult{} + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DML_PK, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set addr = 3 where name = 1 limit 1000" expandedQuery := "select pk from test_table where name = 1 limit 1000 for update" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + want := &mproto.QueryResult{} + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableTx|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_DML_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlSubQueryAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set addr = 3 where name = 1 limit 1000" expandedQuery := "select pk from test_table where name = 1 limit 1000 for update" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + want := &mproto.QueryResult{} + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DML_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanOtherWithinATransaction(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "show test_table" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableTx|enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_OTHER, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassSelectWithInATransaction(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} fields := []mproto.Field{ mproto.Field{Name: "addr", Type: mproto.VT_LONG}, } - query := "select addr from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: fields, RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{sqltypes.MakeString([]byte("123"))}, }, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select addr from test_table where 1 != 1", &mproto.QueryResult{ Fields: fields, }) @@ -278,17 +336,23 @@ func TestQueryExecutorPlanPassSelectWithInATransaction(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassSelectWithLockOutsideATransaction(t *testing.T) { db := setUpQueryExecutorTest() query := "select * from test_table for update" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -297,19 +361,27 @@ func TestQueryExecutorPlanPassSelectWithLockOutsideATransaction(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "query should fail because the select holds a lock but outside a transaction", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanPassSelect(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -318,16 +390,20 @@ func TestQueryExecutorPlanPassSelect(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPKIn(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table where pk in (1, 2, 3) limit 1000" expandedQuery := "select pk, name, addr from test_table where pk in (1, 2, 3)" - - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 1, Rows: [][]sqltypes.Value{ @@ -338,18 +414,22 @@ func TestQueryExecutorPlanPKIn(t *testing.T) { }, }, } - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) - + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) - qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict|enableSchemaOverrides) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PK_IN, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } cachedQuery := "select pk, name, addr from test_table where pk in (1)" db.AddQuery(cachedQuery, &mproto.QueryResult{ @@ -366,24 +446,26 @@ func TestQueryExecutorPlanPKIn(t *testing.T) { nonCachedQuery := "select pk, name, addr from test_table where pk in (2, 3)" db.AddQuery(nonCachedQuery, &mproto.QueryResult{}) - - db.AddQuery(cachedQuery, expected) - + db.AddQuery(cachedQuery, want) // run again, this time pk=1 should hit the rowcache - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanSelectSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table where name = 1 limit 1000" expandedQuery := "select pk from test_table use index (INDEX) where name = 1 limit 1000" - - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), } - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), @@ -393,14 +475,17 @@ func TestQueryExecutorPlanSelectSubQuery(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanSet(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} - setQuery := "set unknown_key = 1" db.AddQuery(setQuery, &mproto.QueryResult{}) qre, sqlQuery := newTestQueryExecutor( @@ -411,9 +496,12 @@ func TestQueryExecutorPlanSet(t *testing.T) { want := &mproto.QueryResult{ Rows: make([][]sqltypes.Value, 0), } - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query: %s failed, got: %+v, want: %+v", setQuery, got, want) + t.Fatalf("qre.Execute() = %v, want: %v", got, want) } sqlQuery.disallowQueries() @@ -424,7 +512,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.connPool.Capacity() != vtPoolSize { t.Fatalf("set query failed, expected to have vt_pool_size: %d, but got: %d", vtPoolSize, qre.qe.connPool.Capacity()) @@ -438,7 +533,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.streamConnPool.Capacity() != vtStreamPoolSize { t.Fatalf("set query failed, expected to have vt_stream_pool_size: %d, but got: %d", vtStreamPoolSize, qre.qe.streamConnPool.Capacity()) } @@ -451,7 +553,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.txPool.pool.Capacity() != vtTransactionCap { t.Fatalf("set query failed, expected to have vt_transaction_cap: %d, but got: %d", vtTransactionCap, qre.qe.txPool.pool.Capacity()) } @@ -464,7 +573,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtTransactionTimeoutInMillis := time.Duration(vtTransactionTimeout) * time.Second if qre.qe.txPool.Timeout() != vtTransactionTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_transaction_timeout: %d, but got: %d", vtTransactionTimeoutInMillis, qre.qe.txPool.Timeout()) @@ -478,7 +594,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtSchemaReloadTimeInMills := time.Duration(vtSchemaReloadTime) * time.Second if qre.qe.schemaInfo.ReloadTime() != vtSchemaReloadTimeInMills { t.Fatalf("set query failed, expected to have vt_schema_reload_time: %d, but got: %d", vtSchemaReloadTimeInMills, qre.qe.schemaInfo.ReloadTime()) @@ -492,7 +615,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if int64(qre.qe.schemaInfo.queries.Capacity()) != vtQueryCacheSize { t.Fatalf("set query failed, expected to have vt_query_cache_size: %d, but got: %d", vtQueryCacheSize, qre.qe.schemaInfo.queries.Capacity()) } @@ -505,7 +635,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtQueryTimeoutInMillis := time.Duration(vtQueryTimeout) * time.Second if qre.qe.queryTimeout.Get() != vtQueryTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_query_timeout: %d, but got: %d", vtQueryTimeoutInMillis, qre.qe.queryTimeout.Get()) @@ -519,7 +656,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtIdleTimeoutInMillis := time.Duration(vtIdleTimeout) * time.Second if qre.qe.connPool.IdleTimeout() != vtIdleTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_idle_timeout: %d, but got: %d in conn pool", vtIdleTimeoutInMillis, qre.qe.connPool.IdleTimeout()) @@ -539,7 +683,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtSpotCheckFreq := int64(vtSpotCheckRatio * spotCheckMultiplier) if qre.qe.spotCheckFreq.Get() != vtSpotCheckFreq { t.Fatalf("set query failed, expected to have vt_spot_check_freq: %d, but got: %d", vtSpotCheckFreq, qre.qe.spotCheckFreq.Get()) @@ -553,7 +704,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.strictMode.Get() != vtStrictMode { t.Fatalf("set query failed, expected to have vt_strict_mode: %d, but got: %d", vtStrictMode, qre.qe.strictMode.Get()) } @@ -566,7 +724,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtTxPoolTimeoutInMillis := time.Duration(vtTxPoolTimeout) * time.Second if qre.qe.txPool.PoolTimeout() != vtTxPoolTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_txpool_timeout: %d, but got: %d", vtTxPoolTimeoutInMillis, qre.qe.txPool.PoolTimeout()) @@ -576,26 +741,44 @@ func TestQueryExecutorPlanSet(t *testing.T) { func TestQueryExecutorPlanSetMaxResultSize(t *testing.T) { setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} + want := &mproto.QueryResult{} vtMaxResultSize := int64(128) setQuery := fmt.Sprintf("set vt_max_result_size = %d", vtMaxResultSize) qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.maxResultSize.Get() != vtMaxResultSize { t.Fatalf("set query failed, expected to have vt_max_result_size: %d, but got: %d", vtMaxResultSize, qre.qe.maxResultSize.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_max_result_size = 0" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetMaxResultSizeFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_max_result_size = 0" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_max_result_size out of range, should always larger than 0", ErrFail) - qre.Execute() + // vt_max_result_size out of range, should always larger than 0 + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanSetMaxDmlRows(t *testing.T) { @@ -607,78 +790,113 @@ func TestQueryExecutorPlanSetMaxDmlRows(t *testing.T) { setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query executor Execute() = %v, want: %v", got, want) + t.Fatalf("qre.Execute() = %v, want: %v", got, want) } if qre.qe.maxDMLRows.Get() != vtMaxDmlRows { t.Fatalf("set query failed, expected to have vt_max_dml_rows: %d, but got: %d", vtMaxDmlRows, qre.qe.maxDMLRows.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_max_dml_rows = 0" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetMaxDmlRowsFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_max_dml_rows = 0" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_max_dml_rows out of range, should always larger than 0", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanSetStreamBufferSize(t *testing.T) { setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} + want := &mproto.QueryResult{} vtStreamBufferSize := int64(2048) setQuery := fmt.Sprintf("set vt_stream_buffer_size = %d", vtStreamBufferSize) qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.streamBufferSize.Get() != vtStreamBufferSize { t.Fatalf("set query failed, expected to have vt_stream_buffer_size: %d, but got: %d", vtStreamBufferSize, qre.qe.streamBufferSize.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_stream_buffer_size = 128" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetStreamBufferSizeFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_stream_buffer_size = 128" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_stream_buffer_size out of range, should always larger than or equal to 1024", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanOther(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "show test_table" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_OTHER, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } } func TestQueryExecutorTableAcl(t *testing.T) { - testUtils := &testUtils{} aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) tableacl.Register(aclName, &simpleacl.Factory{}) tableacl.SetDefaultACL(aclName) - db := setUpQueryExecutorTest() query := "select * from test_table limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -702,11 +920,41 @@ func TestQueryExecutorTableAcl(t *testing.T) { qre, sqlQuery := newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) + defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) - sqlQuery.disallowQueries() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } +} - config = &tableaclpb.Config{ +func TestQueryExecutorTableAclNoPermission(t *testing.T) { + aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) + tableacl.Register(aclName, &simpleacl.Factory{}) + tableacl.SetDefaultACL(aclName) + db := setUpQueryExecutorTest() + query := "select * from test_table limit 1000" + want := &mproto.QueryResult{ + Fields: getTestTableFields(), + RowsAffected: 0, + Rows: [][]sqltypes.Value{}, + } + db.AddQuery(query, want) + db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ + Fields: getTestTableFields(), + }) + + username := "u2" + callInfo := &fakeCallInfo{ + remoteAddr: "1.2.3.4", + username: username, + } + ctx := callinfo.NewContext(context.Background(), callInfo) + + config := &tableaclpb.Config{ TableGroups: []*tableaclpb.TableGroupSpec{{ Name: "group02", TableNamesOrPrefixes: []string{"test_table"}, @@ -718,18 +966,35 @@ func TestQueryExecutorTableAcl(t *testing.T) { t.Fatalf("unable to load tableacl config, error: %v", err) } // without enabling Config.StrictTableAcl - qre, sqlQuery = newTestQueryExecutor( + qre, sqlQuery := newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } sqlQuery.disallowQueries() + // enable Config.StrictTableAcl qre, sqlQuery = newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail) - qre.Execute() + // query should fail because current user do not have read permissions + _, err = qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(tabletError.ErrorType)) + } } func TestQueryExecutorBlacklistQRFail(t *testing.T) { @@ -776,8 +1041,18 @@ func TestQueryExecutorBlacklistQRFail(t *testing.T) { qre, sqlQuery := newTestQueryExecutor(query, ctx, enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "execute should fail because query has been blacklisted", ErrFail) - qre.Execute() + // execute should fail because query has been blacklisted + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorBlacklistQRRetry(t *testing.T) { @@ -824,8 +1099,17 @@ func TestQueryExecutorBlacklistQRRetry(t *testing.T) { qre, sqlQuery := newTestQueryExecutor(query, ctx, enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "execute should fail because query has been blacklisted", ErrRetry) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrRetry { + t.Fatalf("got: %s, want: ErrRetry", getTabletErrorString(got.ErrorType)) + } } type executorFlags int64 diff --git a/go/vt/tabletserver/sqlquery.go b/go/vt/tabletserver/sqlquery.go index 7009d2df3ff..833a39f55c9 100644 --- a/go/vt/tabletserver/sqlquery.go +++ b/go/vt/tabletserver/sqlquery.go @@ -348,28 +348,7 @@ func (sq *SqlQuery) Rollback(ctx context.Context, session *proto.Session) (err e // the supplied error return value. func (sq *SqlQuery) handleExecError(query *proto.Query, err *error, logStats *SQLQueryStats) { if x := recover(); x != nil { - terr, ok := x.(*TabletError) - if !ok { - log.Errorf("Uncaught panic for %v:\n%v\n%s", query, x, tb.Stack(4)) - *err = NewTabletError(ErrFail, "%v: uncaught panic for %v", x, query) - sq.qe.queryServiceStats.InternalErrors.Add("Panic", 1) - return - } - if sq.config.TerseErrors && terr.SqlError != 0 { - *err = fmt.Errorf("%s(errno %d) during query: %s", terr.Prefix(), terr.SqlError, query.Sql) - } else { - *err = terr - } - terr.RecordStats(sq.qe.queryServiceStats) - // suppress these errors in logs - if terr.ErrorType == ErrRetry || terr.ErrorType == ErrTxPoolFull || terr.SqlError == mysql.ErrDupEntry { - return - } - if terr.ErrorType == ErrFatal { - log.Errorf("%v: %v", terr, query) - } else { - log.Warningf("%v: %v", terr, query) - } + *err = sq.handleExecErrorNoPanic(query, x, logStats) } if logStats != nil { logStats.Error = *err @@ -377,6 +356,31 @@ func (sq *SqlQuery) handleExecError(query *proto.Query, err *error, logStats *SQ } } +func (sq *SqlQuery) handleExecErrorNoPanic(query *proto.Query, err interface{}, logStats *SQLQueryStats) error { + terr, ok := err.(*TabletError) + if !ok { + log.Errorf("Uncaught panic for %v:\n%v\n%s", query, err, tb.Stack(4)) + sq.qe.queryServiceStats.InternalErrors.Add("Panic", 1) + return NewTabletError(ErrFail, "%v: uncaught panic for %v", err, query) + } + var myError error + if sq.config.TerseErrors && terr.SqlError != 0 { + myError = fmt.Errorf("%s(errno %d) during query: %s", terr.Prefix(), terr.SqlError, query.Sql) + } else { + myError = terr + } + terr.RecordStats(sq.qe.queryServiceStats) + // suppress these errors in logs + if terr.ErrorType == ErrRetry || terr.ErrorType == ErrTxPoolFull || terr.SqlError == mysql.ErrDupEntry { + return myError + } + if terr.ErrorType == ErrFatal { + log.Errorf("%v: %v", terr, query) + } + log.Warningf("%v: %v", terr, query) + return myError +} + // Execute executes the query and returns the result as response. func (sq *SqlQuery) Execute(ctx context.Context, query *proto.Query, reply *mproto.QueryResult) (err error) { logStats := newSqlQueryStats("Execute", ctx) @@ -405,7 +409,11 @@ func (sq *SqlQuery) Execute(ctx context.Context, query *proto.Query, reply *mpro logStats: logStats, qe: sq.qe, } - *reply = *qre.Execute() + result, err := qre.Execute() + if err != nil { + return sq.handleExecErrorNoPanic(query, err, logStats) + } + *reply = *result return nil } @@ -439,7 +447,10 @@ func (sq *SqlQuery) StreamExecute(ctx context.Context, query *proto.Query, sendR logStats: logStats, qe: sq.qe, } - qre.Stream(sendReply) + err = qre.Stream(sendReply) + if err != nil { + return sq.handleExecErrorNoPanic(query, err, logStats) + } return nil } @@ -528,13 +539,19 @@ func (sq *SqlQuery) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest logStats: logStats, qe: sq.qe, } - conn := qre.getConn(sq.qe.connPool) + conn, err := qre.getConn(sq.qe.connPool) + if err != nil { + return err + } defer conn.Recycle() // TODO: For fetching MinMax, include where clauses on the // primary key, if any, in the original query which might give a narrower // range of split column to work with. - minMaxSql := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", splitter.splitColumn, splitter.splitColumn, splitter.tableName) - splitColumnMinMax := qre.execSQL(conn, minMaxSql, true) + minMaxSQL := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", splitter.splitColumn, splitter.splitColumn, splitter.tableName) + splitColumnMinMax, err := qre.execSQL(conn, minMaxSQL, true) + if err != nil { + return err + } reply.Queries, err = splitter.split(splitColumnMinMax) if err != nil { return NewTabletError(ErrFail, "splitQuery: query split error: %s, request: %#v", err, req)