diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 57ababf4762..b0ec7a2e55f 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -926,9 +926,15 @@ func (c *Conn) handleNextCommand(handler Handler) error { prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) } + bindVars := make(map[string]*querypb.BindVariable, paramsCount) + for i := uint16(0); i < paramsCount; i++ { + parameterID := fmt.Sprintf("v%d", i+1) + bindVars[parameterID] = &querypb.BindVariable{} + } + c.PrepareData[c.StatementID] = prepare - fld, err := handler.ComPrepare(c, queries[0]) + fld, err := handler.ComPrepare(c, queries[0], bindVars) if err != nil { if werr := c.writeErrorPacketFromError(err); werr != nil { diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 5ca771001ca..c7b6f570b56 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -430,7 +430,7 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) { } // ComPrepare is part of the mysql.Handler interface. -func (db *DB) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) { +func (db *DB) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { return nil, nil } diff --git a/go/mysql/server.go b/go/mysql/server.go index d47804b0280..bc6e1e93266 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -101,7 +101,7 @@ type Handler interface { // ComPrepare is called when a connection receives a prepared // statement query. - ComPrepare(c *Conn, query string) ([]*querypb.Field, error) + ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) // ComStmtExecute is called when a connection receives a statement // execute query. diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index a8d7dd3b0df..d84f6db8118 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -221,7 +221,7 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } -func (th *testHandler) ComPrepare(c *Conn, query string) ([]*querypb.Field, error) { +func (th *testHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { return nil, nil } diff --git a/go/test/endtoend/preparestmt/stmt_methods_test.go b/go/test/endtoend/preparestmt/stmt_methods_test.go index 1ee99ad7f44..1b65762ae83 100644 --- a/go/test/endtoend/preparestmt/stmt_methods_test.go +++ b/go/test/endtoend/preparestmt/stmt_methods_test.go @@ -205,6 +205,46 @@ func reconnectAndTest(t *testing.T) { } +// TestColumnParameter query database using column +// parameter. +func TestColumnParameter(t *testing.T) { + defer cluster.PanicHandler(t) + dbo := Connect(t) + defer dbo.Close() + + id := 1000 + parameter1 := "param1" + message := "TestColumnParameter" + insertStmt := "INSERT INTO " + tableName + " (id, msg, keyspace_id) VALUES (?, ?, ?);" + values := []interface{}{ + id, + message, + 2000, + } + exec(t, dbo, insertStmt, values...) + + var param, msg string + var recID int + + selectStmt := "SELECT COALESCE(?, id), msg FROM " + tableName + " WHERE msg = ? LIMIT ?" + + results1, err := dbo.Query(selectStmt, parameter1, message, 1) + require.Nil(t, err) + require.True(t, results1.Next()) + + results1.Scan(¶m, &msg) + assert.Equal(t, parameter1, param) + assert.Equal(t, message, msg) + + results2, err := dbo.Query(selectStmt, nil, message, 1) + require.Nil(t, err) + require.True(t, results2.Next()) + + results2.Scan(&recID, &msg) + assert.Equal(t, id, recID) + assert.Equal(t, message, msg) +} + // TestWrongTableName query database using invalid // tablename and validate error. func TestWrongTableName(t *testing.T) { diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index ff1a7a865db..ded421e99b7 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -218,7 +218,7 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq } // ComPrepare is the handler for command prepare. -func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) { +func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { var ctx context.Context var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -252,7 +252,7 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Fie } }() - session, fld, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) + session, fld, err := vh.vtg.Prepare(ctx, session, query, bindVars) err = mysql.NewSQLErrorFromError(err) if err != nil { return nil, err diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 6a43aa93015..711b1ff861b 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -51,7 +51,7 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } -func (th *testHandler) ComPrepare(c *mysql.Conn, q string) ([]*querypb.Field, error) { +func (th *testHandler) ComPrepare(c *mysql.Conn, q string, b map[string]*querypb.BindVariable) ([]*querypb.Field, error) { return nil, nil }