Skip to content
3 changes: 0 additions & 3 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ const (
const (
// ERVitessMaxRowsExceeded is when a user tries to select more rows than the max rows as enforced by vitess.
ERVitessMaxRowsExceeded = 10001

// ERVitessShardError is when a shard query fails in a partial scatter statement
ERVitessShardError = 10002
)

// Error codes for server-side errors.
Expand Down
18 changes: 18 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (t noopVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc
return func() {}
}

func (t noopVCursor) RecordWarning(warning *querypb.QueryWarning) {
}

func (t noopVCursor) Execute(method string, query string, bindvars map[string]*querypb.BindVariable, isDML bool) (*sqltypes.Result, error) {
panic("unimplemented")
}
Expand Down Expand Up @@ -88,6 +91,8 @@ type loggingVCursor struct {
curResult int
resultErr error

warnings []*querypb.QueryWarning

// Optional errors that can be returned from nextResult() alongside the results for
// multi-shard queries
multiShardErrs []error
Expand All @@ -98,10 +103,15 @@ type loggingVCursor struct {
func (f *loggingVCursor) Context() context.Context {
return context.Background()
}

func (f *loggingVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc {
return func() {}
}

func (f *loggingVCursor) RecordWarning(warning *querypb.QueryWarning) {
f.warnings = append(f.warnings, warning)
}

func (f *loggingVCursor) Execute(method string, query string, bindvars map[string]*querypb.BindVariable, isDML bool) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("Execute %s %v %v", query, printBindVars(bindvars), isDML))
return f.nextResult()
Expand Down Expand Up @@ -206,10 +216,18 @@ func (f *loggingVCursor) ExpectLog(t *testing.T, want []string) {
}
}

func (f *loggingVCursor) ExpectWarnings(t *testing.T, want []*querypb.QueryWarning) {
t.Helper()
if !reflect.DeepEqual(f.warnings, want) {
t.Errorf("vc.warnings:\n%+v\nwant:\n%+v", f.warnings, want)
}
}

func (f *loggingVCursor) Rewind() {
f.curShardForKsid = 0
f.curResult = 0
f.log = nil
f.warnings = nil
}

func (f *loggingVCursor) nextResult() (*sqltypes.Result, error) {
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type VCursor interface {
// SetContextTimeout updates the context and sets a timeout.
SetContextTimeout(timeout time.Duration) context.CancelFunc

// RecordWarning stores the given warning in the current session
RecordWarning(warning *querypb.QueryWarning)

// V3 functions.
Execute(method string, query string, bindvars map[string]*querypb.BindVariable, isDML bool) (*sqltypes.Result, error)
ExecuteAutocommit(method string, query string, bindvars map[string]*querypb.BindVariable, isDML bool) (*sqltypes.Result, error)
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/engine/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"vitess.io/vitess/go/jsonutil"
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/vt/key"
Expand Down Expand Up @@ -224,6 +225,13 @@ func (route *Route) execute(vcursor VCursor, bindVars map[string]*querypb.BindVa
if errs != nil {
if route.ScatterErrorsAsWarnings {
partialSuccessScatterQueries.Add(1)

for _, err := range errs {
if err != nil {
serr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError)
vcursor.RecordWarning(&querypb.QueryWarning{Code: uint32(serr.Num), Message: err.Error()})
}
}
// fall through
} else {
return nil, vterrors.Aggregate(errs)
Expand Down
23 changes: 16 additions & 7 deletions go/vt/vtgate/engine/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"errors"
"testing"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/vindexes"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

var defaultSelectResult = sqltypes.MakeTestResult(
Expand Down Expand Up @@ -793,13 +793,14 @@ func TestExecFail(t *testing.T) {
FieldQuery: "dummy_select_field",
}

vc := &loggingVCursor{shards: []string{"0"}, resultErr: errors.New("result error")}
vc := &loggingVCursor{shards: []string{"0"}, resultErr: mysql.NewSQLError(mysql.ERQueryInterrupted, "", "query timeout")}
_, err := sel.Execute(vc, map[string]*querypb.BindVariable{}, false)
expectError(t, "sel.Execute err", err, "result error")
expectError(t, "sel.Execute err", err, "query timeout (errno 1317) (sqlstate HY000)")
vc.ExpectWarnings(t, nil)

vc.Rewind()
_, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false)
expectError(t, "sel.StreamExecute err", err, "result error")
expectError(t, "sel.StreamExecute err", err, "query timeout (errno 1317) (sqlstate HY000)")

// Scatter fails if one of N fails without ScatterErrorsAsWarnings
sel = &Route{
Expand All @@ -821,6 +822,7 @@ func TestExecFail(t *testing.T) {
}
_, err = sel.Execute(vc, map[string]*querypb.BindVariable{}, false)
expectError(t, "sel.Execute err", err, "result error -20")
vc.ExpectWarnings(t, nil)
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`,
Expand All @@ -844,14 +846,21 @@ func TestExecFail(t *testing.T) {
shards: []string{"-20", "20-"},
results: []*sqltypes.Result{defaultSelectResult},
multiShardErrs: []error{
errors.New("result error -20"),
errors.New("result error 20-"),
mysql.NewSQLError(mysql.ERQueryInterrupted, "", "query timeout -20"),
errors.New("not a sql error 20-"),
},
}
_, err = sel.Execute(vc, map[string]*querypb.BindVariable{}, false)
if err != nil {
t.Errorf("unexpected ScatterErrorsAsWarnings error %v", err)
}

// Ensure that the error code is preserved from SQLErrors and that it
// turns into ERUnknownError for all others
vc.ExpectWarnings(t, []*querypb.QueryWarning{
{Code: mysql.ERQueryInterrupted, Message: "query timeout -20 (errno 1317) (sqlstate HY000)"},
{Code: mysql.ERUnknownError, Message: "not a sql error 20-"},
})
vc.ExpectLog(t, []string{
`ResolveDestinations ks [] Destinations:DestinationAllShards()`,
`ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`,
Expand Down
11 changes: 11 additions & 0 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
stmtType := sqlparser.Preview(sql)
logStats.StmtType = sqlparser.StmtType(stmtType)

// Mysql warnings are scoped to the current session, but are
// cleared when a "non-diagnostic statement" is executed:
// https://dev.mysql.com/doc/refman/8.0/en/show-warnings.html
//
// To emulate this behavior, clear warnings from the session
// for all statements _except_ SHOW, so that SHOW WARNINGS
// can actually return them.
if stmtType != sqlparser.StmtShow {
safeSession.ClearWarnings()
}

switch stmtType {
case sqlparser.StmtSelect:
return e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ func TestExecutorShow(t *testing.T) {

session.Warnings = []*querypb.QueryWarning{
{Code: mysql.ERBadTable, Message: "bad table"},
{Code: mysql.ERVitessShardError, Message: "ks/-40: query timed out"},
{Code: mysql.EROutOfResources, Message: "ks/-40: query timed out"},
}
qr, err = executor.Execute(context.Background(), "TestExecute", session, "show warnings", nil)
wantqr = &sqltypes.Result{
Expand All @@ -782,7 +782,7 @@ func TestExecutorShow(t *testing.T) {

Rows: [][]sqltypes.Value{
{sqltypes.NewVarChar("Warning"), sqltypes.NewUint32(mysql.ERBadTable), sqltypes.NewVarChar("bad table")},
{sqltypes.NewVarChar("Warning"), sqltypes.NewUint32(mysql.ERVitessShardError), sqltypes.NewVarChar("ks/-40: query timed out")},
{sqltypes.NewVarChar("Warning"), sqltypes.NewUint32(mysql.EROutOfResources), sqltypes.NewVarChar("ks/-40: query timed out")},
},
RowsAffected: 0,
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq
}

func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 {
session, _ := c.ClientData.(*vtgatepb.Session)
if session != nil {
return uint16(len(session.GetWarnings()))
}
return 0
}

Expand Down
15 changes: 15 additions & 0 deletions go/vt/vtgate/safe_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/golang/protobuf/proto"
"vitess.io/vitess/go/vt/vterrors"

querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
Expand Down Expand Up @@ -182,6 +183,20 @@ func (session *SafeSession) MustRollback() bool {
return session.mustRollback
}

// RecordWarning stores the given warning in the session
func (session *SafeSession) RecordWarning(warning *querypb.QueryWarning) {
session.mu.Lock()
defer session.mu.Unlock()
session.Session.Warnings = append(session.Session.Warnings, warning)
}

// ClearWarnings removes all the warnings from the session
func (session *SafeSession) ClearWarnings() {
session.mu.Lock()
defer session.mu.Unlock()
session.Session.Warnings = nil
}

// Reset clears the session
func (session *SafeSession) Reset() {
if session == nil || session.Session == nil {
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func (vc *vcursorImpl) SetContextTimeout(timeout time.Duration) context.CancelFu
return cancel
}

// RecordWarning stores the given warning in the current session
func (vc *vcursorImpl) RecordWarning(warning *querypb.QueryWarning) {
vc.safeSession.RecordWarning(warning)
}

// FindTable finds the specified table. If the keyspace what specified in the input, it gets used as qualifier.
// Otherwise, the keyspace from the request is used, if one was provided.
func (vc *vcursorImpl) FindTable(name sqlparser.TableName) (*vindexes.Table, string, topodatapb.TabletType, key.Destination, error) {
Expand Down
4 changes: 4 additions & 0 deletions py/vtdb/grpc_vtgate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def message_ack(

return response.result.rows_affected

def get_warnings(self):
if self.session:
return self.session.warnings
return []

def _convert_exception(exc, *args, **kwargs):
"""This parses the protocol exceptions to the api interface exceptions.
Expand Down
10 changes: 10 additions & 0 deletions py/vtdb/vtgate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,13 @@ def message_ack(self,
dbexceptions.FatalError: this query should not be retried.
"""
raise NotImplementedError('Child class needs to implement this')

def get_warnings(self):
"""Get warnings from the previous query

Returns:
The list of warnings.

"""
raise NotImplementedError('Child class needs to implement this')

58 changes: 58 additions & 0 deletions test/mysql_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import environment
import utils
import tablet
import warnings

# single shard / 2 tablets
shard_0_master = tablet.Tablet()
Expand Down Expand Up @@ -202,6 +203,63 @@ def test_mysql_connector(self):
self.assertIn('1317', s)
conn.close()

# this query should fail due to the bogus field
conn = MySQLdb.Connect(**params)
try:
cursor = conn.cursor()
cursor.execute('SELECT invalid_field from vt_insert_test', {})
self.fail('Execute went through')
except MySQLdb.OperationalError, e:
s = str(e)
# 1054 is BadFieldError code
self.assertIn('1054', s)

# this query should trigger a warning not an error
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

cursor.execute('SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS */ invalid_field from vt_insert_test', {})
if cursor.rowcount != 0:
self.fail('expected 0 rows got ' + str(cursor.rowcount))

if len(w) != 1:
print 'unexpected warnings: ', w

# and the next query should get the warnings
cursor.execute('SHOW WARNINGS', {})
if cursor.rowcount != 1:
print 'expected 1 warning row, got ' + str(cursor.rowcount)

for (_, code, message) in cursor:
self.assertEqual(code, 1054)
self.assertIn('errno 1054', message)
self.assertIn('Unknown column', message)

# test with a query timeout error
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")

cursor.execute('SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS QUERY_TIMEOUT_MS=1 */ sleep(1) from vt_insert_test', {})
if cursor.rowcount != 0:
self.fail('expected 0 rows got ' + str(cursor.rowcount))

if len(w) != 1:
print 'unexpected warnings: ', w

cursor.execute('SHOW WARNINGS', {})
if cursor.rowcount != 1:
print 'expected 1 warning row, got ' + str(cursor.rowcount)

for (_, code, message) in cursor:
self.assertEqual(code, 1317)
self.assertIn('context deadline exceeded', message)

# any non-show query clears the warnings
cursor.execute('SELECT 1 from vt_insert_test limit 1', {})
cursor.execute('SHOW WARNINGS', {})
if cursor.rowcount != 0:
print 'expected 0 warnings row, got ' + str(cursor.rowcount)

# 'vtgate client 2' is not authorized to access vt_insert_test
params['user'] = 'testuser2'
params['passwd'] = 'testpassword2'
Expand Down
27 changes: 27 additions & 0 deletions test/vtgatev3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,33 @@ def test_user(self):
([(0,)], 1L, 0,
[(u'SLEEP(1)', self.int_type)]))

# test shard errors as warnings directive
cursor.execute('SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS */ bad from vt_user', {})
print vtgate_conn.get_warnings()
warnings = vtgate_conn.get_warnings()
self.assertEqual(len(warnings), 2)
for warning in warnings:
self.assertEqual(warning.code, 1054)
self.assertIn('errno 1054', warning.message)
self.assertIn('Unknown column', warning.message)
self.assertEqual(
(cursor.fetchall(), cursor.rowcount, cursor.lastrowid,
cursor.description),
([], 0L, 0, []))

# test shard errors as warnings directive with timeout
cursor.execute('SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS QUERY_TIMEOUT_MS=10 */ SLEEP(1)', {})
print vtgate_conn.get_warnings()
warnings = vtgate_conn.get_warnings()
self.assertEqual(len(warnings), 1)
for warning in warnings:
self.assertEqual(warning.code, 1317)
self.assertIn('context deadline exceeded', warning.message)
self.assertEqual(
(cursor.fetchall(), cursor.rowcount, cursor.lastrowid,
cursor.description),
([], 0L, 0, []))

# Test insert with no auto-inc
vtgate_conn.begin()
result = self.execute_on_master(
Expand Down