Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
case ComInitDB:
db := c.parseComInitDB(data)
c.recycleReadPacket()
c.schemaName = db
handler.ComInitDB(c, db)
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComInitDB result to %s: %v", c, err)
if err := c.execQuery(fmt.Sprintf("use `%s`", db), handler, false); err != nil {
return err
}
case ComQuery:
Expand Down
4 changes: 0 additions & 4 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,6 @@ func (db *DB) ConnectionClosed(c *mysql.Conn) {
delete(db.connections, c.ConnectionID)
}

// ComInitDB is part of the mysql.Handler interface.
func (db *DB) ComInitDB(c *mysql.Conn, schemaName string) {
}

// ComQuery is part of the mysql.Handler interface.
func (db *DB) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
return db.Handler.HandleQuery(c, query, callback)
Expand Down
19 changes: 12 additions & 7 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mysql

import (
"crypto/tls"
"fmt"
"io"
"net"
"strings"
Expand Down Expand Up @@ -90,10 +91,6 @@ type Handler interface {
// ConnectionClosed is called when a connection is closed.
ConnectionClosed(c *Conn)

// InitDB is called once at the beginning to set db name,
// and subsequently for every ComInitDB event.
ComInitDB(c *Conn, schemaName string)

// ComQuery is called when a connection receives a query.
// Note the contents of the query slice may change after
// the first call to callback. So the Handler should not
Expand Down Expand Up @@ -441,6 +438,17 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
defer connCountPerUser.Add(c.User, -1)
}

// Set initial db name.
if c.schemaName != "" {
err = l.handler.ComQuery(c, fmt.Sprintf("use `%s`", c.schemaName), func(result *sqltypes.Result) error {
return nil
})
if err != nil {
c.writeErrorPacketFromError(err)
return
}
}

// Negotiation worked, send OK packet.
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Cannot write OK packet to %s: %v", c, err)
Expand All @@ -457,9 +465,6 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
log.Warningf("Slow connection from %s: %v", c, connectTime)
}

// Set initial db name.
l.handler.ComInitDB(c, c.schemaName)

for {
err := c.handleNextCommand(l.handler)
if err != nil {
Expand Down
3 changes: 0 additions & 3 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ func (th *testHandler) NewConnection(c *Conn) {
func (th *testHandler) ConnectionClosed(c *Conn) {
}

func (th *testHandler) ComInitDB(c *Conn, schemaName string) {
}

func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
if result := th.Result(); result != nil {
callback(result)
Expand Down
7 changes: 1 addition & 6 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
// addNeededBindVars adds bind vars that are needed by the plan
func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error {
if bindVarNeeds.NeedDatabase {
keyspace, _, _, _ := e.ParseDestinationTarget(session.TargetString)
if keyspace == "" {
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
} else {
bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace)
}
bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(session.TargetString)
}

if bindVarNeeds.NeedLastInsertID {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func TestSelectDatabase(t *testing.T) {
{Name: "database()", Type: sqltypes.VarBinary},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary("TestExecutor"),
sqltypes.NewVarBinary("TestExecutor@master"),
}},
}
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func TestDirectTargetRewrites(t *testing.T) {
require.NoError(t, err)
testQueries(t, "sbclookup", sbclookup, []*querypb.BoundQuery{{
Sql: "select :__vtdbname as `database()` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded")},
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded/0@master")},
}})
}

Expand Down Expand Up @@ -1038,7 +1038,7 @@ func TestExecutorUse(t *testing.T) {
}

_, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil)
wantErr = "invalid keyspace provided: UnexistentKeyspace"
wantErr = "Unknown database 'UnexistentKeyspace' (errno 1049) (sqlstate 42000)"
if err == nil || err.Error() != wantErr {
t.Errorf("got: %v, want %v", err, wantErr)
}
Expand Down
16 changes: 3 additions & 13 deletions go/vt/vtgate/mysql_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestMySQLProtocolExecuteUseStatement(t *testing.T) {
// No such keyspace this will fail
_, err = c.ExecuteFetch("use InvalidKeyspace", 0, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid keyspace provided: InvalidKeyspace")
assert.Contains(t, err.Error(), "Unknown database 'InvalidKeyspace' (errno 1049) (sqlstate 42000)")

// That doesn't reset the vitess_target
qr, err = c.ExecuteFetch("show vitess_target", 1, false)
Expand All @@ -135,18 +135,8 @@ func TestMySQLProtocolExecuteUseStatement(t *testing.T) {
}

func TestMysqlProtocolInvalidDB(t *testing.T) {
c, err := mysqlConnect(&mysql.ConnParams{DbName: "invalidDB"})
if err != nil {
t.Fatal(err)
}
defer c.Close()

_, err = c.ExecuteFetch("select id from t1", 10, true /* wantfields */)
c.Close()
want := "vtgate: : keyspace invalidDB not found in vschema (errno 1105) (sqlstate HY000) during query: select id from t1"
if err == nil || err.Error() != want {
t.Errorf("exec with db:\n%v, want\n%s", err, want)
}
_, err := mysqlConnect(&mysql.ConnParams{DbName: "invalidDB"})
require.EqualError(t, err, "vtgate: : Unknown database 'invalidDB' (errno 1049) (sqlstate 42000) (errno 1049) (sqlstate 42000)")
}

func TestMySQLProtocolClientFoundRows(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/plan_executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func TestPlanSelectDatabase(t *testing.T) {
{Name: "database()", Type: sqltypes.VarBinary},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary("TestExecutor"),
sqltypes.NewVarBinary("TestExecutor@master"),
}},
}
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/plan_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func TestPlanDirectTargetRewrites(t *testing.T) {
require.NoError(t, err)
testQueries(t, "sbclookup", sbclookup, []*querypb.BoundQuery{{
Sql: "select :__vtdbname as `database()` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded")},
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded/0@master")},
}})
}

Expand Down Expand Up @@ -996,7 +996,7 @@ func TestPlanExecutorUse(t *testing.T) {
}

_, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil)
wantErr = "invalid keyspace provided: UnexistentKeyspace"
wantErr = "Unknown database 'UnexistentKeyspace' (errno 1049) (sqlstate 42000)"
if err == nil || err.Error() != wantErr {
t.Errorf("got: %v, want %v", err, wantErr)
}
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,6 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co
return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString)
}

func (vh *vtgateHandler) ComInitDB(c *mysql.Conn, schemaName string) {
vh.session(c).TargetString = schemaName
}

func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
ctx := context.Background()
var cancel context.CancelFunc
Expand Down
3 changes: 0 additions & 3 deletions go/vt/vtgate/plugin_mysql_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ func (th *testHandler) NewConnection(c *mysql.Conn) {
func (th *testHandler) ConnectionClosed(c *mysql.Conn) {
}

func (th *testHandler) ComInitDB(c *mysql.Conn, schemaName string) {
}

func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error {
return nil
}
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"sync/atomic"
"time"

"vitess.io/vitess/go/mysql"

"vitess.io/vitess/go/vt/callerid"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
"vitess.io/vitess/go/vt/topotools"
Expand Down Expand Up @@ -322,7 +324,7 @@ func (vc *vcursorImpl) SetTarget(target string) error {
return err
}
if _, ok := vc.vschema.Keyspaces[keyspace]; keyspace != "" && !ok {
return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid keyspace provided: %s", keyspace)
return mysql.NewSQLError(mysql.ERBadDb, "42000", "Unknown database '%s'", keyspace)
}

if vc.safeSession.InTransaction() && tabletType != topodatapb.TabletType_MASTER {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/vcursor_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestSetTarget(t *testing.T) {
}, {
vschema: vschemaWith2KS,
targetString: "ks3",
expectedError: "invalid keyspace provided: ks3",
expectedError: "Unknown database 'ks3' (errno 1049) (sqlstate 42000)",
}, {
vschema: vschemaWith2KS,
targetString: "ks2@replica",
Expand Down