diff --git a/go/vt/tabletserver/codex.go b/go/vt/tabletserver/codex.go index 75d5bd406fd..64a52b3649f 100644 --- a/go/vt/tabletserver/codex.go +++ b/go/vt/tabletserver/codex.go @@ -40,12 +40,13 @@ func buildValueList(tableInfo *TableInfo, pkValues []interface{}, bindVars map[s func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) (resolved []interface{}, length int, err error) { length = -1 - setLength := func(list []sqltypes.Value) { + setLengthFunc := func(list []sqltypes.Value) error { if length == -1 { length = len(list) } else if len(list) != length { - panic(NewTabletError(ErrFail, "mismatched lengths for values %v", pkValues)) + return NewTabletError(ErrFail, "mismatched lengths for values %v", pkValues) } + return nil } resolved = make([]interface{}, len(pkValues)) for i, val := range pkValues { @@ -61,7 +62,9 @@ func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[ if err != nil { return nil, 0, err } - setLength(list) + if err := setLengthFunc(list); err != nil { + return nil, 0, err + } resolved[i] = list } case []interface{}: @@ -72,7 +75,9 @@ func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[ return nil, 0, err } } - setLength(list) + if err := setLengthFunc(list); err != nil { + return nil, 0, err + } resolved[i] = list default: resolved[i], err = resolveValue(tableInfo.GetPKColumn(i), val, nil) @@ -144,7 +149,7 @@ func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[strin case sqltypes.Value: result = v default: - panic(NewTabletError(ErrFail, "incompatible value type %v", v)) + return result, NewTabletError(ErrFail, "incompatible value type %v", v) } if err = validateValue(col, result); err != nil { @@ -185,12 +190,12 @@ func validateValue(col *schema.TableColumn, value sqltypes.Value) error { // getLimit resolves the rowcount or offset of the limit clause value. // It returns -1 if it's not set. -func getLimit(limit interface{}, bv map[string]interface{}) int64 { +func getLimit(limit interface{}, bv map[string]interface{}) (int64, error) { switch lim := limit.(type) { case string: lookup, ok := bv[lim[1:]] if !ok { - panic(NewTabletError(ErrFail, "missing bind var %s", lim)) + return -1, NewTabletError(ErrFail, "missing bind var %s", lim) } var newlim int64 switch l := lookup.(type) { @@ -201,16 +206,16 @@ func getLimit(limit interface{}, bv map[string]interface{}) int64 { case int: newlim = int64(l) default: - panic(NewTabletError(ErrFail, "want number type for %s, got %T", lim, lookup)) + return -1, NewTabletError(ErrFail, "want number type for %s, got %T", lim, lookup) } if newlim < 0 { - panic(NewTabletError(ErrFail, "negative limit %d", newlim)) + return -1, NewTabletError(ErrFail, "negative limit %d", newlim) } - return newlim + return newlim, nil case int64: - return lim + return lim, nil default: - return -1 + return -1, nil } } diff --git a/go/vt/tabletserver/codex_test.go b/go/vt/tabletserver/codex_test.go index c956eaae35f..3779ad212fd 100644 --- a/go/vt/tabletserver/codex_test.go +++ b/go/vt/tabletserver/codex_test.go @@ -229,10 +229,8 @@ func TestCodexResolvePKValues(t *testing.T) { pkValues = make([]interface{}, 0, 10) pkValues = append(pkValues, []interface{}{":" + key}) pkValues = append(pkValues, []interface{}{":" + key2, ":" + key3}) - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "mismatched lengths") - _, _, err = resolvePKValues(&tableInfo, pkValues, bindVariables) - }() + _, _, err = resolvePKValues(&tableInfo, pkValues, bindVariables) + testUtils.checkTabletError(t, err, ErrFail, "mismatched lengths") } func TestCodexResolveListArg(t *testing.T) { @@ -322,11 +320,8 @@ func TestCodexResolveValueWithIncompatibleValueType(t *testing.T) { []string{"pk1", "pk2", "col1"}, []string{"int", "varbinary(128)", "int"}, []string{"pk1", "pk2"}) - - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "incompatible value type ") - resolveValue(tableInfo.GetPKColumn(0), 0, nil) - }() + _, err := resolveValue(tableInfo.GetPKColumn(0), 0, nil) + testUtils.checkTabletError(t, err, ErrFail, "incompatible value type ") } func TestCodexValidateRow(t *testing.T) { @@ -352,46 +347,52 @@ func TestCodexGetLimit(t *testing.T) { "uint": uint(1), } testUtils := newTestUtils() - // to handle panics - func() { - defer testUtils.checkTabletErrorWithRecover(t, ErrFail, "missing bind var") - getLimit(":unknown", bv) - }() - if result := getLimit(int64(1), bv); result != 1 { + _, err := getLimit(":unknown", bv) + if err == nil { + t.Fatal("got nil, want error: missing bind var") + } + testUtils.checkTabletError(t, err, ErrFail, "missing bind var") + result, err := getLimit(int64(1), bv) + if err != nil { + t.Fatalf("getLimit(1, bv) = %v, want nil", err) + } + if result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(nil, bv); result != -1 { + result, err = getLimit(nil, bv) + if err != nil { + t.Fatalf("getLimit(nil, bv) = %v, want nil", err) + } + if result != -1 { t.Fatalf("got %d, want -1", result) } - func() { - defer func() { - x := recover().(error).Error() - want := "error: negative limit -1" - if x != want { - t.Fatalf("got %s, want %s", x, want) - } - }() - getLimit(":negative", bv) - }() - if result := getLimit(":int64", bv); result != 1 { + + result, err = getLimit(":negative", bv) + if err == nil { + t.Fatalf("getLimit(':negative', bv) should return an error") + } + want := "error: negative limit -1" + if err.Error() != want { + t.Fatalf("got %s, want %s", err.Error(), want) + } + if result, _ := getLimit(":int64", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(":int32", bv); result != 1 { + if result, _ := getLimit(":int32", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - if result := getLimit(":int", bv); result != 1 { + if result, _ := getLimit(":int", bv); result != 1 { t.Fatalf("got %d, want 1", result) } - func() { - defer func() { - x := recover().(error).Error() - want := "error: want number type for :uint, got uint" - if x != want { - t.Fatalf("got %s, want %s", x, want) - } - }() - getLimit(":uint", bv) - }() + + _, err = getLimit(":uint", bv) + if err == nil { + t.Fatalf("getLimit(':uint', bv) should return an error") + } + want = "error: want number type for :uint, got uint" + if err.Error() != want { + t.Fatalf("got %s, want %s", err.Error(), want) + } } func TestCodexBuildKey(t *testing.T) { diff --git a/go/vt/tabletserver/dbconn.go b/go/vt/tabletserver/dbconn.go index 21f0bb1aa99..8800acd6fdd 100644 --- a/go/vt/tabletserver/dbconn.go +++ b/go/vt/tabletserver/dbconn.go @@ -5,6 +5,7 @@ package tabletserver import ( + "errors" "fmt" "time" @@ -67,7 +68,7 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel return nil, NewTabletErrorSql(ErrFatal, err) } } - panic("unreachable") + return nil, NewTabletErrorSql(ErrFatal, errors.New("dbconn.Exec: unreachable code")) } func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*mproto.QueryResult, error) { diff --git a/go/vt/tabletserver/query_executor.go b/go/vt/tabletserver/query_executor.go index 73b1ed46c82..a1eda00ea59 100644 --- a/go/vt/tabletserver/query_executor.go +++ b/go/vt/tabletserver/query_executor.go @@ -36,7 +36,7 @@ type poolConn interface { } // Execute performs a non-streaming query execution. -func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { +func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult, err error) { qre.logStats.OriginalSql = qre.query qre.logStats.BindVariables = qre.bindVars qre.logStats.TransactionID = qre.transactionID @@ -55,7 +55,9 @@ func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { qre.qe.queryServiceStats.ResultStats.Add(int64(len(reply.Rows))) }(time.Now()) - qre.checkPermissions() + if err := qre.checkPermissions(); err != nil { + return nil, err + } if qre.plan.PlanId == planbuilder.PLAN_DDL { return qre.execDDL() @@ -73,79 +75,89 @@ func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult) { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_DML: if qre.qe.strictMode.Get() != 0 { - panic(NewTabletError(ErrFail, "DML too complex")) + return nil, NewTabletError(ErrFail, "DML too complex") } - reply = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + reply, err = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) case planbuilder.PLAN_INSERT_PK: - reply = qre.execInsertPK(conn) + reply, err = qre.execInsertPK(conn) case planbuilder.PLAN_INSERT_SUBQUERY: - reply = qre.execInsertSubquery(conn) + reply, err = qre.execInsertSubquery(conn) case planbuilder.PLAN_DML_PK: - reply = qre.execDMLPK(conn, invalidator) + reply, err = qre.execDMLPK(conn, invalidator) case planbuilder.PLAN_DML_SUBQUERY: - reply = qre.execDMLSubquery(conn, invalidator) + reply, err = qre.execDMLSubquery(conn, invalidator) case planbuilder.PLAN_OTHER: - reply = qre.execSQL(conn, qre.query, true) + reply, err = qre.execSQL(conn, qre.query, true) default: // select or set in a transaction, just count as select - reply = qre.execDirect(conn) + reply, err = qre.execDirect(conn) } } else { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_SELECT: if qre.plan.Reason == planbuilder.REASON_LOCK { - panic(NewTabletError(ErrFail, "Disallowed outside transaction")) + return nil, NewTabletError(ErrFail, "Disallowed outside transaction") } - reply = qre.execSelect() + reply, err = qre.execSelect() case planbuilder.PLAN_PK_IN: - reply = qre.execPKIN() + reply, err = qre.execPKIN() case planbuilder.PLAN_SELECT_SUBQUERY: - reply = qre.execSubquery() + reply, err = qre.execSubquery() case planbuilder.PLAN_SET: - reply = qre.execSet() + reply, err = qre.execSet() case planbuilder.PLAN_OTHER: - conn := qre.getConn(qre.qe.connPool) + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err + } defer conn.Recycle() - reply = qre.execSQL(conn, qre.query, true) + reply, err = qre.execSQL(conn, qre.query, true) default: if !qre.qe.enableAutoCommit { - panic(NewTabletError(ErrFatal, "unsupported query: %s", qre.query)) + return nil, NewTabletError(ErrFatal, "unsupported query: %s", qre.query) } - reply = qre.execDmlAutoCommit() + reply, err = qre.execDmlAutoCommit() } } - return reply + return reply, err } // Stream performs a streaming query execution. -func (qre *QueryExecutor) Stream(sendReply func(*mproto.QueryResult) error) { +func (qre *QueryExecutor) Stream(sendReply func(*mproto.QueryResult) error) error { qre.logStats.OriginalSql = qre.query qre.logStats.PlanType = qre.plan.PlanId.String() defer qre.qe.queryServiceStats.QueryStats.Record(qre.plan.PlanId.String(), time.Now()) - qre.checkPermissions() + if err := qre.checkPermissions(); err != nil { + return err + } - conn := qre.getConn(qre.qe.streamConnPool) + conn, err := qre.getConn(qre.qe.streamConnPool) + if err != nil { + return err + } defer conn.Recycle() qd := NewQueryDetail(qre.logStats.ctx, conn) qre.qe.streamQList.Add(qd) defer qre.qe.streamQList.Remove(qd) - qre.fullStreamFetch(conn, qre.plan.FullQuery, qre.bindVars, nil, sendReply) + return qre.fullStreamFetch(conn, qre.plan.FullQuery, qre.bindVars, nil, sendReply) } -func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult) { +func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult, err error) { transactionID := qre.qe.txPool.Begin(qre.ctx) qre.logStats.AddRewrittenSql("begin", time.Now()) defer func() { - err := recover() - if err == nil { - qre.qe.Commit(qre.ctx, qre.logStats, transactionID) - qre.logStats.AddRewrittenSql("commit", time.Now()) - } else { + // TxPool.Get may panic + if panicErr := recover(); panicErr != nil { + err = fmt.Errorf("DML autocommit got panic: %v", panicErr) + } + if err != nil { qre.qe.txPool.Rollback(qre.ctx, transactionID) qre.logStats.AddRewrittenSql("rollback", time.Now()) - panic(err) + } else { + qre.qe.Commit(qre.ctx, qre.logStats, transactionID) + qre.logStats.AddRewrittenSql("commit", time.Now()) } }() conn := qre.qe.txPool.Get(transactionID) @@ -157,27 +169,27 @@ func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult) { switch qre.plan.PlanId { case planbuilder.PLAN_PASS_DML: if qre.qe.strictMode.Get() != 0 { - panic(NewTabletError(ErrFail, "DML too complex")) + return nil, NewTabletError(ErrFail, "DML too complex") } - reply = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + reply, err = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) case planbuilder.PLAN_INSERT_PK: - reply = qre.execInsertPK(conn) + reply, err = qre.execInsertPK(conn) case planbuilder.PLAN_INSERT_SUBQUERY: - reply = qre.execInsertSubquery(conn) + reply, err = qre.execInsertSubquery(conn) case planbuilder.PLAN_DML_PK: - reply = qre.execDMLPK(conn, invalidator) + reply, err = qre.execDMLPK(conn, invalidator) case planbuilder.PLAN_DML_SUBQUERY: - reply = qre.execDMLSubquery(conn, invalidator) + reply, err = qre.execDMLSubquery(conn, invalidator) default: - panic(NewTabletError(ErrFatal, "unsupported query: %s", qre.query)) + return nil, NewTabletError(ErrFatal, "unsupported query: %s", qre.query) } - return reply + return reply, nil } -func (qre *QueryExecutor) checkPermissions() { +func (qre *QueryExecutor) checkPermissions() error { // Skip permissions check if we have a background context. if qre.ctx == context.Background() { - return + return nil } // Blacklist @@ -191,9 +203,9 @@ func (qre *QueryExecutor) checkPermissions() { action, desc := qre.plan.Rules.getAction(remoteAddr, username, qre.bindVars) switch action { case QR_FAIL: - panic(NewTabletError(ErrFail, "Query disallowed due to rule: %s", desc)) + return NewTabletError(ErrFail, "Query disallowed due to rule: %s", desc) case QR_FAIL_RETRY: - panic(NewTabletError(ErrRetry, "Query disallowed due to rule: %s", desc)) + return NewTabletError(ErrRetry, "Query disallowed due to rule: %s", desc) } // Perform table ACL check if it is enabled @@ -201,16 +213,17 @@ func (qre *QueryExecutor) checkPermissions() { errStr := fmt.Sprintf("table acl error: %q cannot run %v on table %q", username, qre.plan.PlanId, qre.plan.TableName) // Raise error if in strictTableAcl mode, else just log an error if qre.qe.strictTableAcl { - panic(NewTabletError(ErrFail, "%s", errStr)) + return NewTabletError(ErrFail, "%s", errStr) } qre.qe.accessCheckerLogger.Errorf("%s", errStr) } + return nil } -func (qre *QueryExecutor) execDDL() *mproto.QueryResult { +func (qre *QueryExecutor) execDDL() (*mproto.QueryResult, error) { ddlPlan := planbuilder.DDLParse(qre.query) if ddlPlan.Action == "" { - panic(NewTabletError(ErrFail, "DDL is not understood")) + return nil, NewTabletError(ErrFail, "DDL is not understood") } txid := qre.qe.txPool.Begin(qre.ctx) @@ -219,8 +232,10 @@ func (qre *QueryExecutor) execDDL() *mproto.QueryResult { // Stolen from Execute conn := qre.qe.txPool.Get(txid) defer conn.Recycle() - result := qre.execSQL(conn, qre.query, false) - + result, err := qre.execSQL(conn, qre.query, false) + if err != nil { + return nil, err + } if ddlPlan.TableName != "" && ddlPlan.TableName != ddlPlan.NewName { // It's a drop or rename. qre.qe.schemaInfo.DropTable(ddlPlan.TableName) @@ -228,31 +243,37 @@ func (qre *QueryExecutor) execDDL() *mproto.QueryResult { if ddlPlan.NewName != "" { qre.qe.schemaInfo.CreateOrUpdateTable(qre.ctx, ddlPlan.NewName) } - return result + return result, nil } -func (qre *QueryExecutor) execPKIN() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execPKIN() (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } - return qre.fetchMulti(pkRows, getLimit(qre.plan.Limit, qre.bindVars)) + limit, err := getLimit(qre.plan.Limit, qre.bindVars) + if err != nil { + return nil, err + } + return qre.fetchMulti(pkRows, limit) } -func (qre *QueryExecutor) execSubquery() (result *mproto.QueryResult) { - innerResult := qre.qFetch(qre.logStats, qre.plan.Subquery, qre.bindVars) +func (qre *QueryExecutor) execSubquery() (*mproto.QueryResult, error) { + innerResult, err := qre.qFetch(qre.logStats, qre.plan.Subquery, qre.bindVars) + if err != nil { + return nil, err + } return qre.fetchMulti(innerResult.Rows, -1) } -func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (result *mproto.QueryResult) { +func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (*mproto.QueryResult, error) { if qre.plan.Fields == nil { - panic("unexpected") + return nil, NewTabletError(ErrFatal, "query plan.Fields is empty") } - result = &mproto.QueryResult{Fields: qre.plan.Fields} + result := &mproto.QueryResult{Fields: qre.plan.Fields} if len(pkRows) == 0 || limit == 0 { - return + return result, nil } - tableInfo := qre.plan.TableInfo keys := make([]string, len(pkRows)) for i, pk := range pkRows { @@ -266,7 +287,10 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re rcresult := rcresults[keys[i]] if rcresult.Row != nil { if qre.mustVerify() { - qre.spotCheck(rcresult, pk) + err := qre.spotCheck(rcresult, pk) + if err != nil { + return nil, err + } } rows = append(rows, applyFilter(qre.plan.ColumnNumbers, rcresult.Row)) hits++ @@ -281,7 +305,10 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re Rows: missingRows, }, } - resultFromdb := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + resultFromdb, err := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + if err != nil { + return nil, err + } misses = int64(len(resultFromdb.Rows)) absent = int64(len(pkRows)) - hits - misses for _, row := range resultFromdb.Rows { @@ -307,14 +334,14 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re result.Rows = result.Rows[:limit] result.RowsAffected = uint64(limit) } - return result + return result, nil } func (qre *QueryExecutor) mustVerify() bool { return (Rand() % spotCheckMultiplier) < qre.qe.spotCheckFreq.Get() } -func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { +func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) error { qre.qe.queryServiceStats.SpotCheckCount.Add(1) bv := map[string]interface{}{ "#pk": sqlparser.TupleEqualityList{ @@ -322,7 +349,10 @@ func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { Rows: [][]sqltypes.Value{pk}, }, } - resultFromdb := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + resultFromdb, err := qre.qFetch(qre.logStats, qre.plan.OuterQuery, bv) + if err != nil { + return err + } var dbrow []sqltypes.Value if len(resultFromdb.Rows) != 0 { dbrow = resultFromdb.Rows[0] @@ -330,6 +360,7 @@ func (qre *QueryExecutor) spotCheck(rcresult RCResult, pk []sqltypes.Value) { if dbrow == nil || !rowsAreEqual(rcresult.Row, dbrow) { qre.qe.Launch(func() { qre.recheckLater(rcresult, dbrow, pk) }) } + return nil } func (qre *QueryExecutor) recheckLater(rcresult RCResult, dbrow []sqltypes.Value, pk []sqltypes.Value) { @@ -347,45 +378,56 @@ func (qre *QueryExecutor) recheckLater(rcresult RCResult, dbrow []sqltypes.Value } // execDirect always sends the query to mysql -func (qre *QueryExecutor) execDirect(conn poolConn) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDirect(conn poolConn) (*mproto.QueryResult, error) { if qre.plan.Fields != nil { - result = qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + result, err := qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) + if err != nil { + return nil, err + } result.Fields = qre.plan.Fields - return + return result, err } - result = qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) - return + return qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } // execSelect sends a query to mysql only if another identical query is not running. Otherwise, it waits and // reuses the result. If the plan is missng field info, it sends the query to mysql requesting full info. -func (qre *QueryExecutor) execSelect() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execSelect() (*mproto.QueryResult, error) { if qre.plan.Fields != nil { - result = qre.qFetch(qre.logStats, qre.plan.FullQuery, qre.bindVars) + result, err := qre.qFetch(qre.logStats, qre.plan.FullQuery, qre.bindVars) + if err != nil { + return nil, err + } result.Fields = qre.plan.Fields - return + return result, nil + } + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err } - conn := qre.getConn(qre.qe.connPool) defer conn.Recycle() return qre.fullFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } -func (qre *QueryExecutor) execInsertPK(conn poolConn) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execInsertPK(conn poolConn) (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } return qre.execInsertPKRows(conn, pkRows) } -func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (result *mproto.QueryResult) { - innerResult := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) +func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (*mproto.QueryResult, error) { + innerResult, err := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) + if err != nil { + return nil, err + } innerRows := innerResult.Rows if len(innerRows) == 0 { - return &mproto.QueryResult{RowsAffected: 0} + return &mproto.QueryResult{RowsAffected: 0}, nil } if len(qre.plan.ColumnNumbers) != len(innerRows[0]) { - panic(NewTabletError(ErrFail, "Subquery length does not match column list")) + return nil, NewTabletError(ErrFail, "Subquery length does not match column list") } pkRows := make([][]sqltypes.Value, len(innerRows)) for i, innerRow := range innerRows { @@ -393,46 +435,48 @@ func (qre *QueryExecutor) execInsertSubquery(conn poolConn) (result *mproto.Quer } // Validating first row is sufficient if err := validateRow(qre.plan.TableInfo, qre.plan.TableInfo.PKColumns, pkRows[0]); err != nil { - panic(err) + return nil, err } qre.bindVars["#values"] = innerRows return qre.execInsertPKRows(conn, pkRows) } -func (qre *QueryExecutor) execInsertPKRows(conn poolConn, pkRows [][]sqltypes.Value) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execInsertPKRows(conn poolConn, pkRows [][]sqltypes.Value) (*mproto.QueryResult, error) { secondaryList, err := buildSecondaryList(qre.plan.TableInfo, pkRows, qre.plan.SecondaryPKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } bsc := buildStreamComment(qre.plan.TableInfo, pkRows, secondaryList) - result = qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) - return result + return qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) } -func (qre *QueryExecutor) execDMLPK(conn poolConn, invalidator CacheInvalidator) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDMLPK(conn poolConn, invalidator CacheInvalidator) (*mproto.QueryResult, error) { pkRows, err := buildValueList(qre.plan.TableInfo, qre.plan.PKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } return qre.execDMLPKRows(conn, pkRows, invalidator) } -func (qre *QueryExecutor) execDMLSubquery(conn poolConn, invalidator CacheInvalidator) (result *mproto.QueryResult) { - innerResult := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) +func (qre *QueryExecutor) execDMLSubquery(conn poolConn, invalidator CacheInvalidator) (*mproto.QueryResult, error) { + innerResult, err := qre.directFetch(conn, qre.plan.Subquery, qre.bindVars, nil) + if err != nil { + return nil, err + } return qre.execDMLPKRows(conn, innerResult.Rows, invalidator) } -func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value, invalidator CacheInvalidator) (result *mproto.QueryResult) { +func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value, invalidator CacheInvalidator) (*mproto.QueryResult, error) { if len(pkRows) == 0 { - return &mproto.QueryResult{RowsAffected: 0} + return &mproto.QueryResult{RowsAffected: 0}, nil } secondaryList, err := buildSecondaryList(qre.plan.TableInfo, pkRows, qre.plan.SecondaryPKValues, qre.bindVars) if err != nil { - panic(err) + return nil, err } - result = &mproto.QueryResult{} + result := &mproto.QueryResult{} maxRows := int(qre.qe.maxDMLRows.Get()) for i := 0; i < len(pkRows); i += maxRows { end := i + maxRows @@ -449,93 +493,157 @@ func (qre *QueryExecutor) execDMLPKRows(conn poolConn, pkRows [][]sqltypes.Value Columns: qre.plan.TableInfo.Indexes[0].Columns, Rows: pkRows, } - r := qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) + r, err := qre.directFetch(conn, qre.plan.OuterQuery, qre.bindVars, bsc) + if err != nil { + return nil, err + } // DMLs should only return RowsAffected. result.RowsAffected += r.RowsAffected } if invalidator == nil { - return result + return result, nil } for _, pk := range pkRows { key := buildKey(pk) invalidator.Delete(key) } - return result + return result, nil } -func (qre *QueryExecutor) execSet() (result *mproto.QueryResult) { +func (qre *QueryExecutor) execSet() (*mproto.QueryResult, error) { switch qre.plan.SetKey { case "vt_pool_size": - qre.qe.connPool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_pool_size = %v, want to int64", err) + } + qre.qe.connPool.SetCapacity(int(val)) case "vt_stream_pool_size": - qre.qe.streamConnPool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_stream_pool_size = %v, want int64", err) + } + qre.qe.streamConnPool.SetCapacity(int(val)) case "vt_transaction_cap": - qre.qe.txPool.pool.SetCapacity(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_transaction_cap = %v, want int64", err) + } + qre.qe.txPool.pool.SetCapacity(int(val)) case "vt_transaction_timeout": - qre.qe.txPool.SetTimeout(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_transaction_timeout = %v, want int64 or float64", err) + } + qre.qe.txPool.SetTimeout(val) case "vt_schema_reload_time": - qre.qe.schemaInfo.SetReloadTime(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_schema_reload_time = %v, want int64 or float64", err) + } + qre.qe.schemaInfo.SetReloadTime(val) case "vt_query_cache_size": - qre.qe.schemaInfo.SetQueryCacheSize(int(getInt64(qre.plan.SetValue))) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_query_cache_size = %v, want int64", err) + } + qre.qe.schemaInfo.SetQueryCacheSize(int(val)) case "vt_max_result_size": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_max_result_size = %v, want int64", err) + } if val < 1 { - panic(NewTabletError(ErrFail, "vt_max_result_size out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_max_result_size out of range %v", val) } qre.qe.maxResultSize.Set(val) case "vt_max_dml_rows": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_max_dml_rows = %v, want to int64", err) + } if val < 1 { - panic(NewTabletError(ErrFail, "vt_max_dml_rows out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_max_dml_rows out of range %v", val) } qre.qe.maxDMLRows.Set(val) case "vt_stream_buffer_size": - val := getInt64(qre.plan.SetValue) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_stream_buffer_size = %v, want int64", err) + } + if val < 1024 { - panic(NewTabletError(ErrFail, "vt_stream_buffer_size out of range %v", val)) + return nil, NewTabletError(ErrFail, "vt_stream_buffer_size out of range %v", val) } qre.qe.streamBufferSize.Set(val) case "vt_query_timeout": - qre.qe.queryTimeout.Set(getDuration(qre.plan.SetValue)) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_query_timeout = %v, want int64 or float64", err) + } + qre.qe.queryTimeout.Set(val) case "vt_idle_timeout": - t := getDuration(qre.plan.SetValue) - qre.qe.connPool.SetIdleTimeout(t) - qre.qe.streamConnPool.SetIdleTimeout(t) - qre.qe.txPool.pool.SetIdleTimeout(t) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_idle_timeout = %v, want int64 or float64", err) + } + qre.qe.connPool.SetIdleTimeout(val) + qre.qe.streamConnPool.SetIdleTimeout(val) + qre.qe.txPool.pool.SetIdleTimeout(val) case "vt_spot_check_ratio": - qre.qe.spotCheckFreq.Set(int64(getFloat64(qre.plan.SetValue) * spotCheckMultiplier)) + val, err := parseFloat64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_spot_check_ratio = %v, want float64", err) + } + qre.qe.spotCheckFreq.Set(int64(val * spotCheckMultiplier)) case "vt_strict_mode": - qre.qe.strictMode.Set(getInt64(qre.plan.SetValue)) + val, err := parseInt64(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_strict_mode = %v, want to int64", err) + } + qre.qe.strictMode.Set(val) case "vt_txpool_timeout": - t := getDuration(qre.plan.SetValue) - qre.qe.txPool.SetPoolTimeout(t) + val, err := parseDuration(qre.plan.SetValue) + if err != nil { + return nil, NewTabletError(ErrFail, "got set vt_txpool_timeout = %v, want int64 or float64", err) + } + qre.qe.txPool.SetPoolTimeout(val) default: - conn := qre.getConn(qre.qe.connPool) + conn, err := qre.getConn(qre.qe.connPool) + if err != nil { + return nil, err + } defer conn.Recycle() return qre.directFetch(conn, qre.plan.FullQuery, qre.bindVars, nil) } - return &mproto.QueryResult{} + return &mproto.QueryResult{}, nil } -func getInt64(v interface{}) int64 { +func parseInt64(v interface{}) (int64, error) { if ival, ok := v.(int64); ok { - return ival + return ival, nil } - panic(NewTabletError(ErrFail, "expecting int")) + return -1, NewTabletError(ErrFail, "got %v, want int64", v) } -func getFloat64(v interface{}) float64 { +func parseFloat64(v interface{}) (float64, error) { if ival, ok := v.(int64); ok { - return float64(ival) + return float64(ival), nil } if fval, ok := v.(float64); ok { - return fval + return fval, nil } - panic(NewTabletError(ErrFail, "expecting number")) + return -1, NewTabletError(ErrFail, "got %v, want int64 or float64", v) } -func getDuration(v interface{}) time.Duration { - return time.Duration(getFloat64(v) * 1e9) +func parseDuration(v interface{}) (time.Duration, error) { + val, err := parseFloat64(v) + if err != nil { + return 0, err + } + // time.Duration is an int64, have to multiple by 1e9 because + // val might be in range (0, 1) + return time.Duration(val * 1e9), nil } func rowsAreEqual(row1, row2 []sqltypes.Value) bool { @@ -553,21 +661,24 @@ func rowsAreEqual(row1, row2 []sqltypes.Value) bool { return true } -func (qre *QueryExecutor) getConn(pool *ConnPool) *DBConn { +func (qre *QueryExecutor) getConn(pool *ConnPool) (*DBConn, error) { start := time.Now() conn, err := pool.Get(qre.ctx) switch err { case nil: qre.logStats.WaitingForConnection += time.Now().Sub(start) - return conn + return conn, nil case ErrConnPoolClosed: - panic(err) + return nil, err } - panic(NewTabletErrorSql(ErrFatal, err)) + return nil, NewTabletErrorSql(ErrFatal, err) } -func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, nil) +func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, nil) + if err != nil { + return nil, err + } q, ok := qre.qe.consolidator.Create(string(sql)) if ok { defer q.Broadcast() @@ -578,7 +689,7 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser q.Err = NewTabletErrorSql(ErrFatal, err) } else { defer conn.Recycle() - q.Result, q.Err = qre.execSQLNoPanic(conn, sql, false) + q.Result, q.Err = qre.execSQL(conn, sql, false) } } else { logStats.QuerySources |= QuerySourceConsolidator @@ -587,59 +698,61 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser qre.qe.queryServiceStats.WaitStats.Record("Consolidations", startTime) } if q.Err != nil { - panic(q.Err) + return nil, q.Err } - return q.Result.(*mproto.QueryResult) + return q.Result.(*mproto.QueryResult), nil } -func (qre *QueryExecutor) directFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) +func (qre *QueryExecutor) directFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return nil, err + } return qre.execSQL(conn, sql, false) } // fullFetch also fetches field info -func (qre *QueryExecutor) fullFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (result *mproto.QueryResult) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) +func (qre *QueryExecutor) fullFetch(conn poolConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (*mproto.QueryResult, error) { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return nil, err + } return qre.execSQL(conn, sql, true) } -func (qre *QueryExecutor) fullStreamFetch(conn *DBConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte, callback func(*mproto.QueryResult) error) { - sql := qre.generateFinalSql(parsedQuery, bindVars, buildStreamComment) - qre.execStreamSQL(conn, sql, callback) +func (qre *QueryExecutor) fullStreamFetch(conn *DBConn, parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte, callback func(*mproto.QueryResult) error) error { + sql, err := qre.generateFinalSQL(parsedQuery, bindVars, buildStreamComment) + if err != nil { + return err + } + return qre.execStreamSQL(conn, sql, callback) } -func (qre *QueryExecutor) generateFinalSql(parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) string { +func (qre *QueryExecutor) generateFinalSQL(parsedQuery *sqlparser.ParsedQuery, bindVars map[string]interface{}, buildStreamComment []byte) (string, error) { bindVars["#maxLimit"] = qre.qe.maxResultSize.Get() + 1 sql, err := parsedQuery.GenerateQuery(bindVars) if err != nil { - panic(NewTabletError(ErrFail, "%s", err)) + return "", NewTabletError(ErrFail, "%s", err) } if buildStreamComment != nil { sql = append(sql, buildStreamComment...) } // undo hack done by stripTrailing sql = restoreTrailing(sql, bindVars) - return hack.String(sql) -} - -func (qre *QueryExecutor) execSQL(conn poolConn, sql string, wantfields bool) *mproto.QueryResult { - result, err := qre.execSQLNoPanic(conn, sql, wantfields) - if err != nil { - panic(err) - } - return result + return hack.String(sql), nil } -func (qre *QueryExecutor) execSQLNoPanic(conn poolConn, sql string, wantfields bool) (*mproto.QueryResult, error) { +func (qre *QueryExecutor) execSQL(conn poolConn, sql string, wantfields bool) (*mproto.QueryResult, error) { defer qre.logStats.AddRewrittenSql(sql, time.Now()) return conn.Exec(qre.ctx, sql, int(qre.qe.maxResultSize.Get()), wantfields) } -func (qre *QueryExecutor) execStreamSQL(conn *DBConn, sql string, callback func(*mproto.QueryResult) error) { +func (qre *QueryExecutor) execStreamSQL(conn *DBConn, sql string, callback func(*mproto.QueryResult) error) error { start := time.Now() err := conn.Stream(qre.ctx, sql, callback, int(qre.qe.streamBufferSize.Get())) qre.logStats.AddRewrittenSql(sql, start) if err != nil { - panic(NewTabletErrorSql(ErrFail, err)) + return NewTabletErrorSql(ErrFail, err) } + return nil } diff --git a/go/vt/tabletserver/query_executor_test.go b/go/vt/tabletserver/query_executor_test.go index abc7c5b3b47..f091bff0b47 100644 --- a/go/vt/tabletserver/query_executor_test.go +++ b/go/vt/tabletserver/query_executor_test.go @@ -26,31 +26,41 @@ import ( func TestQueryExecutorPlanDDL(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "alter table test_table add zipcode int" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DDL, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassDmlStrictMode(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set pk = foo()" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) // non strict mode qre, sqlQuery := newTestQueryExecutor(query, context.Background(), enableTx) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } testCommitHelper(t, sqlQuery, qre) sqlQuery.disallowQueries() @@ -62,39 +72,57 @@ func TestQueryExecutorPlanPassDmlStrictMode(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - defer handleAndVerifyTabletError( - t, - "update should fail because strict mode is enabled", - ErrFail) - qre.Execute() + got, err = qre.Execute() + if err == nil { + t.Fatal("qre.Execute() = nil, want error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: a TabletError", tabletError) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(ErrFail)) + } } func TestQueryExecutorPlanPassDmlStrictModeAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set pk = foo()" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) // non strict mode qre, sqlQuery := newTestQueryExecutor(query, context.Background(), noFlags) checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } sqlQuery.disallowQueries() // strict mode + // update should fail because strict mode is not enabled qre, sqlQuery = newTestQueryExecutor( "update test_table set pk = foo()", context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_DML, qre.plan.PlanId) - defer handleAndVerifyTabletError( - t, - "update should fail because strict mode is not enabled", - ErrFail) - qre.Execute() + _, err = qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", tabletError) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(ErrFail)) + } } func TestQueryExecutorPlanInsertPk(t *testing.T) { @@ -110,20 +138,22 @@ func TestQueryExecutorPlanInsertPk(t *testing.T) { enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_INSERT_PK, qre.plan.PlanId) - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query: %s, QueryExecutor.Execute() = %v, want: %v", sql, got, want) + t.Fatalf("got: %v, want: %v", got, want) } } func TestQueryExecutorPlanInsertSubQueryAutoCommmit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "insert into test_table(pk) select pk from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) selectQuery := "select pk from test_table where pk = 1 limit 1000" db.AddQuery(selectQuery, &mproto.QueryResult{ RowsAffected: 1, @@ -140,17 +170,22 @@ func TestQueryExecutorPlanInsertSubQueryAutoCommmit(t *testing.T) { query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_INSERT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanInsertSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "insert into test_table(pk) select pk from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) selectQuery := "select pk from test_table where pk = 1 limit 1000" db.AddQuery(selectQuery, &mproto.QueryResult{ RowsAffected: 1, @@ -168,107 +203,130 @@ func TestQueryExecutorPlanInsertSubQuery(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_INSERT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlPk(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set name = 2 where pk in (1) /* _stream test_table (pk ) (1 ); */" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - + want := &mproto.QueryResult{} + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableTx|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_DML_PK, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set name = 2 where pk in (1) /* _stream test_table (pk ) (1 ); */" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - + want := &mproto.QueryResult{} + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DML_PK, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set addr = 3 where name = 1 limit 1000" expandedQuery := "select pk from test_table where name = 1 limit 1000 for update" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + want := &mproto.QueryResult{} + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableTx|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_DML_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanDmlSubQueryAutoCommit(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "update test_table set addr = 3 where name = 1 limit 1000" expandedQuery := "select pk from test_table where name = 1 limit 1000 for update" - // expected.Rows is always nil, intended - expected := &mproto.QueryResult{} - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + want := &mproto.QueryResult{} + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_DML_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanOtherWithinATransaction(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "show test_table" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableTx|enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_OTHER, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassSelectWithInATransaction(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} fields := []mproto.Field{ mproto.Field{Name: "addr", Type: mproto.VT_LONG}, } - query := "select addr from test_table where pk = 1 limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: fields, RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{sqltypes.MakeString([]byte("123"))}, }, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select addr from test_table where 1 != 1", &mproto.QueryResult{ Fields: fields, }) @@ -278,17 +336,23 @@ func TestQueryExecutorPlanPassSelectWithInATransaction(t *testing.T) { defer sqlQuery.disallowQueries() defer testCommitHelper(t, sqlQuery, qre) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPassSelectWithLockOutsideATransaction(t *testing.T) { db := setUpQueryExecutorTest() query := "select * from test_table for update" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -297,19 +361,27 @@ func TestQueryExecutorPlanPassSelectWithLockOutsideATransaction(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "query should fail because the select holds a lock but outside a transaction", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanPassSelect(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -318,16 +390,20 @@ func TestQueryExecutorPlanPassSelect(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanPKIn(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table where pk in (1, 2, 3) limit 1000" expandedQuery := "select pk, name, addr from test_table where pk in (1, 2, 3)" - - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 1, Rows: [][]sqltypes.Value{ @@ -338,18 +414,22 @@ func TestQueryExecutorPlanPKIn(t *testing.T) { }, }, } - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) - + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) - qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableStrict|enableSchemaOverrides) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PK_IN, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } cachedQuery := "select pk, name, addr from test_table where pk in (1)" db.AddQuery(cachedQuery, &mproto.QueryResult{ @@ -366,24 +446,26 @@ func TestQueryExecutorPlanPKIn(t *testing.T) { nonCachedQuery := "select pk, name, addr from test_table where pk in (2, 3)" db.AddQuery(nonCachedQuery, &mproto.QueryResult{}) - - db.AddQuery(cachedQuery, expected) - + db.AddQuery(cachedQuery, want) // run again, this time pk=1 should hit the rowcache - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanSelectSubQuery(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "select * from test_table where name = 1 limit 1000" expandedQuery := "select pk from test_table use index (INDEX) where name = 1 limit 1000" - - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), } - db.AddQuery(query, expected) - db.AddQuery(expandedQuery, expected) + db.AddQuery(query, want) + db.AddQuery(expandedQuery, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), @@ -393,14 +475,17 @@ func TestQueryExecutorPlanSelectSubQuery(t *testing.T) { query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } } func TestQueryExecutorPlanSet(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} - setQuery := "set unknown_key = 1" db.AddQuery(setQuery, &mproto.QueryResult{}) qre, sqlQuery := newTestQueryExecutor( @@ -411,9 +496,12 @@ func TestQueryExecutorPlanSet(t *testing.T) { want := &mproto.QueryResult{ Rows: make([][]sqltypes.Value, 0), } - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query: %s failed, got: %+v, want: %+v", setQuery, got, want) + t.Fatalf("qre.Execute() = %v, want: %v", got, want) } sqlQuery.disallowQueries() @@ -424,7 +512,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.connPool.Capacity() != vtPoolSize { t.Fatalf("set query failed, expected to have vt_pool_size: %d, but got: %d", vtPoolSize, qre.qe.connPool.Capacity()) @@ -438,7 +533,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.streamConnPool.Capacity() != vtStreamPoolSize { t.Fatalf("set query failed, expected to have vt_stream_pool_size: %d, but got: %d", vtStreamPoolSize, qre.qe.streamConnPool.Capacity()) } @@ -451,7 +553,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.txPool.pool.Capacity() != vtTransactionCap { t.Fatalf("set query failed, expected to have vt_transaction_cap: %d, but got: %d", vtTransactionCap, qre.qe.txPool.pool.Capacity()) } @@ -464,7 +573,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtTransactionTimeoutInMillis := time.Duration(vtTransactionTimeout) * time.Second if qre.qe.txPool.Timeout() != vtTransactionTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_transaction_timeout: %d, but got: %d", vtTransactionTimeoutInMillis, qre.qe.txPool.Timeout()) @@ -478,7 +594,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtSchemaReloadTimeInMills := time.Duration(vtSchemaReloadTime) * time.Second if qre.qe.schemaInfo.ReloadTime() != vtSchemaReloadTimeInMills { t.Fatalf("set query failed, expected to have vt_schema_reload_time: %d, but got: %d", vtSchemaReloadTimeInMills, qre.qe.schemaInfo.ReloadTime()) @@ -492,7 +615,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if int64(qre.qe.schemaInfo.queries.Capacity()) != vtQueryCacheSize { t.Fatalf("set query failed, expected to have vt_query_cache_size: %d, but got: %d", vtQueryCacheSize, qre.qe.schemaInfo.queries.Capacity()) } @@ -505,7 +635,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtQueryTimeoutInMillis := time.Duration(vtQueryTimeout) * time.Second if qre.qe.queryTimeout.Get() != vtQueryTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_query_timeout: %d, but got: %d", vtQueryTimeoutInMillis, qre.qe.queryTimeout.Get()) @@ -519,7 +656,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtIdleTimeoutInMillis := time.Duration(vtIdleTimeout) * time.Second if qre.qe.connPool.IdleTimeout() != vtIdleTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_idle_timeout: %d, but got: %d in conn pool", vtIdleTimeoutInMillis, qre.qe.connPool.IdleTimeout()) @@ -539,7 +683,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtSpotCheckFreq := int64(vtSpotCheckRatio * spotCheckMultiplier) if qre.qe.spotCheckFreq.Get() != vtSpotCheckFreq { t.Fatalf("set query failed, expected to have vt_spot_check_freq: %d, but got: %d", vtSpotCheckFreq, qre.qe.spotCheckFreq.Get()) @@ -553,7 +704,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.strictMode.Get() != vtStrictMode { t.Fatalf("set query failed, expected to have vt_strict_mode: %d, but got: %d", vtStrictMode, qre.qe.strictMode.Get()) } @@ -566,7 +724,14 @@ func TestQueryExecutorPlanSet(t *testing.T) { qre, sqlQuery = newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err = qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + want = &mproto.QueryResult{} + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } vtTxPoolTimeoutInMillis := time.Duration(vtTxPoolTimeout) * time.Second if qre.qe.txPool.PoolTimeout() != vtTxPoolTimeoutInMillis { t.Fatalf("set query failed, expected to have vt_txpool_timeout: %d, but got: %d", vtTxPoolTimeoutInMillis, qre.qe.txPool.PoolTimeout()) @@ -576,26 +741,44 @@ func TestQueryExecutorPlanSet(t *testing.T) { func TestQueryExecutorPlanSetMaxResultSize(t *testing.T) { setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} + want := &mproto.QueryResult{} vtMaxResultSize := int64(128) setQuery := fmt.Sprintf("set vt_max_result_size = %d", vtMaxResultSize) qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.maxResultSize.Get() != vtMaxResultSize { t.Fatalf("set query failed, expected to have vt_max_result_size: %d, but got: %d", vtMaxResultSize, qre.qe.maxResultSize.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_max_result_size = 0" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetMaxResultSizeFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_max_result_size = 0" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_max_result_size out of range, should always larger than 0", ErrFail) - qre.Execute() + // vt_max_result_size out of range, should always larger than 0 + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanSetMaxDmlRows(t *testing.T) { @@ -607,78 +790,113 @@ func TestQueryExecutorPlanSetMaxDmlRows(t *testing.T) { setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - got := qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } if !reflect.DeepEqual(got, want) { - t.Fatalf("query executor Execute() = %v, want: %v", got, want) + t.Fatalf("qre.Execute() = %v, want: %v", got, want) } if qre.qe.maxDMLRows.Get() != vtMaxDmlRows { t.Fatalf("set query failed, expected to have vt_max_dml_rows: %d, but got: %d", vtMaxDmlRows, qre.qe.maxDMLRows.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_max_dml_rows = 0" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetMaxDmlRowsFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_max_dml_rows = 0" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_max_dml_rows out of range, should always larger than 0", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanSetStreamBufferSize(t *testing.T) { setUpQueryExecutorTest() - testUtils := &testUtils{} - expected := &mproto.QueryResult{} + want := &mproto.QueryResult{} vtStreamBufferSize := int64(2048) setQuery := fmt.Sprintf("set vt_stream_buffer_size = %d", vtStreamBufferSize) qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } if qre.qe.streamBufferSize.Get() != vtStreamBufferSize { t.Fatalf("set query failed, expected to have vt_stream_buffer_size: %d, but got: %d", vtStreamBufferSize, qre.qe.streamBufferSize.Get()) } - // set vt_max_result_size fail - setQuery = "set vt_stream_buffer_size = 128" - qre, sqlQuery = newTestQueryExecutor( +} + +func TestQueryExecutorPlanSetStreamBufferSizeFail(t *testing.T) { + setUpQueryExecutorTest() + setQuery := "set vt_stream_buffer_size = 128" + qre, sqlQuery := newTestQueryExecutor( setQuery, context.Background(), enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SET, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "vt_stream_buffer_size out of range, should always larger than or equal to 1024", ErrFail) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorPlanOther(t *testing.T) { db := setUpQueryExecutorTest() - testUtils := &testUtils{} query := "show test_table" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) qre, sqlQuery := newTestQueryExecutor( query, context.Background(), enableRowCache|enableSchemaOverrides|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_OTHER, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } } func TestQueryExecutorTableAcl(t *testing.T) { - testUtils := &testUtils{} aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) tableacl.Register(aclName, &simpleacl.Factory{}) tableacl.SetDefaultACL(aclName) - db := setUpQueryExecutorTest() query := "select * from test_table limit 1000" - expected := &mproto.QueryResult{ + want := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } - db.AddQuery(query, expected) + db.AddQuery(query, want) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) @@ -702,11 +920,41 @@ func TestQueryExecutorTableAcl(t *testing.T) { qre, sqlQuery := newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) + defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - testUtils.checkEqual(t, expected, qre.Execute()) - sqlQuery.disallowQueries() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } +} - config = &tableaclpb.Config{ +func TestQueryExecutorTableAclNoPermission(t *testing.T) { + aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) + tableacl.Register(aclName, &simpleacl.Factory{}) + tableacl.SetDefaultACL(aclName) + db := setUpQueryExecutorTest() + query := "select * from test_table limit 1000" + want := &mproto.QueryResult{ + Fields: getTestTableFields(), + RowsAffected: 0, + Rows: [][]sqltypes.Value{}, + } + db.AddQuery(query, want) + db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ + Fields: getTestTableFields(), + }) + + username := "u2" + callInfo := &fakeCallInfo{ + remoteAddr: "1.2.3.4", + username: username, + } + ctx := callinfo.NewContext(context.Background(), callInfo) + + config := &tableaclpb.Config{ TableGroups: []*tableaclpb.TableGroupSpec{{ Name: "group02", TableNamesOrPrefixes: []string{"test_table"}, @@ -718,18 +966,35 @@ func TestQueryExecutorTableAcl(t *testing.T) { t.Fatalf("unable to load tableacl config, error: %v", err) } // without enabling Config.StrictTableAcl - qre, sqlQuery = newTestQueryExecutor( + qre, sqlQuery := newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - qre.Execute() + got, err := qre.Execute() + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("qre.Execute() = %v, want: %v", got, want) + } sqlQuery.disallowQueries() + // enable Config.StrictTableAcl qre, sqlQuery = newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail) - qre.Execute() + // query should fail because current user do not have read permissions + _, err = qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + tabletError, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if tabletError.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(tabletError.ErrorType)) + } } func TestQueryExecutorBlacklistQRFail(t *testing.T) { @@ -776,8 +1041,18 @@ func TestQueryExecutorBlacklistQRFail(t *testing.T) { qre, sqlQuery := newTestQueryExecutor(query, ctx, enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "execute should fail because query has been blacklisted", ErrFail) - qre.Execute() + // execute should fail because query has been blacklisted + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrFail { + t.Fatalf("got: %s, want: ErrFail", getTabletErrorString(got.ErrorType)) + } } func TestQueryExecutorBlacklistQRRetry(t *testing.T) { @@ -824,8 +1099,17 @@ func TestQueryExecutorBlacklistQRRetry(t *testing.T) { qre, sqlQuery := newTestQueryExecutor(query, ctx, enableRowCache|enableStrict) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_SELECT_SUBQUERY, qre.plan.PlanId) - defer handleAndVerifyTabletError(t, "execute should fail because query has been blacklisted", ErrRetry) - qre.Execute() + _, err := qre.Execute() + if err == nil { + t.Fatal("got: nil, want: error") + } + got, ok := err.(*TabletError) + if !ok { + t.Fatalf("got: %v, want: *TabletError", err) + } + if got.ErrorType != ErrRetry { + t.Fatalf("got: %s, want: ErrRetry", getTabletErrorString(got.ErrorType)) + } } type executorFlags int64 diff --git a/go/vt/tabletserver/sqlquery.go b/go/vt/tabletserver/sqlquery.go index 7009d2df3ff..833a39f55c9 100644 --- a/go/vt/tabletserver/sqlquery.go +++ b/go/vt/tabletserver/sqlquery.go @@ -348,28 +348,7 @@ func (sq *SqlQuery) Rollback(ctx context.Context, session *proto.Session) (err e // the supplied error return value. func (sq *SqlQuery) handleExecError(query *proto.Query, err *error, logStats *SQLQueryStats) { if x := recover(); x != nil { - terr, ok := x.(*TabletError) - if !ok { - log.Errorf("Uncaught panic for %v:\n%v\n%s", query, x, tb.Stack(4)) - *err = NewTabletError(ErrFail, "%v: uncaught panic for %v", x, query) - sq.qe.queryServiceStats.InternalErrors.Add("Panic", 1) - return - } - if sq.config.TerseErrors && terr.SqlError != 0 { - *err = fmt.Errorf("%s(errno %d) during query: %s", terr.Prefix(), terr.SqlError, query.Sql) - } else { - *err = terr - } - terr.RecordStats(sq.qe.queryServiceStats) - // suppress these errors in logs - if terr.ErrorType == ErrRetry || terr.ErrorType == ErrTxPoolFull || terr.SqlError == mysql.ErrDupEntry { - return - } - if terr.ErrorType == ErrFatal { - log.Errorf("%v: %v", terr, query) - } else { - log.Warningf("%v: %v", terr, query) - } + *err = sq.handleExecErrorNoPanic(query, x, logStats) } if logStats != nil { logStats.Error = *err @@ -377,6 +356,31 @@ func (sq *SqlQuery) handleExecError(query *proto.Query, err *error, logStats *SQ } } +func (sq *SqlQuery) handleExecErrorNoPanic(query *proto.Query, err interface{}, logStats *SQLQueryStats) error { + terr, ok := err.(*TabletError) + if !ok { + log.Errorf("Uncaught panic for %v:\n%v\n%s", query, err, tb.Stack(4)) + sq.qe.queryServiceStats.InternalErrors.Add("Panic", 1) + return NewTabletError(ErrFail, "%v: uncaught panic for %v", err, query) + } + var myError error + if sq.config.TerseErrors && terr.SqlError != 0 { + myError = fmt.Errorf("%s(errno %d) during query: %s", terr.Prefix(), terr.SqlError, query.Sql) + } else { + myError = terr + } + terr.RecordStats(sq.qe.queryServiceStats) + // suppress these errors in logs + if terr.ErrorType == ErrRetry || terr.ErrorType == ErrTxPoolFull || terr.SqlError == mysql.ErrDupEntry { + return myError + } + if terr.ErrorType == ErrFatal { + log.Errorf("%v: %v", terr, query) + } + log.Warningf("%v: %v", terr, query) + return myError +} + // Execute executes the query and returns the result as response. func (sq *SqlQuery) Execute(ctx context.Context, query *proto.Query, reply *mproto.QueryResult) (err error) { logStats := newSqlQueryStats("Execute", ctx) @@ -405,7 +409,11 @@ func (sq *SqlQuery) Execute(ctx context.Context, query *proto.Query, reply *mpro logStats: logStats, qe: sq.qe, } - *reply = *qre.Execute() + result, err := qre.Execute() + if err != nil { + return sq.handleExecErrorNoPanic(query, err, logStats) + } + *reply = *result return nil } @@ -439,7 +447,10 @@ func (sq *SqlQuery) StreamExecute(ctx context.Context, query *proto.Query, sendR logStats: logStats, qe: sq.qe, } - qre.Stream(sendReply) + err = qre.Stream(sendReply) + if err != nil { + return sq.handleExecErrorNoPanic(query, err, logStats) + } return nil } @@ -528,13 +539,19 @@ func (sq *SqlQuery) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest logStats: logStats, qe: sq.qe, } - conn := qre.getConn(sq.qe.connPool) + conn, err := qre.getConn(sq.qe.connPool) + if err != nil { + return err + } defer conn.Recycle() // TODO: For fetching MinMax, include where clauses on the // primary key, if any, in the original query which might give a narrower // range of split column to work with. - minMaxSql := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", splitter.splitColumn, splitter.splitColumn, splitter.tableName) - splitColumnMinMax := qre.execSQL(conn, minMaxSql, true) + minMaxSQL := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", splitter.splitColumn, splitter.splitColumn, splitter.tableName) + splitColumnMinMax, err := qre.execSQL(conn, minMaxSQL, true) + if err != nil { + return err + } reply.Queries, err = splitter.split(splitColumnMinMax) if err != nil { return NewTabletError(ErrFail, "splitQuery: query split error: %s, request: %#v", err, req)