diff --git a/go/mysql/mysql_fuzzer.go b/go/mysql/mysql_fuzzer.go index af5a4d53e9e..d863845afa2 100644 --- a/go/mysql/mysql_fuzzer.go +++ b/go/mysql/mysql_fuzzer.go @@ -18,13 +18,19 @@ limitations under the License. package mysql import ( + "context" "fmt" + "io/ioutil" "net" + "os" + "path" "sync" "time" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/tlstest" + "vitess.io/vitess/go/vt/vttls" ) func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) { @@ -216,3 +222,135 @@ func FuzzReadQueryResults(data []byte) int { } return 1 } + +type fuzzTestHandler struct { + mu sync.Mutex + lastConn *Conn + result *sqltypes.Result + err error + warnings uint16 +} + +func (th *fuzzTestHandler) LastConn() *Conn { + th.mu.Lock() + defer th.mu.Unlock() + return th.lastConn +} + +func (th *fuzzTestHandler) Result() *sqltypes.Result { + th.mu.Lock() + defer th.mu.Unlock() + return th.result +} + +func (th *fuzzTestHandler) SetErr(err error) { + th.mu.Lock() + defer th.mu.Unlock() + th.err = err +} + +func (th *fuzzTestHandler) Err() error { + th.mu.Lock() + defer th.mu.Unlock() + return th.err +} + +func (th *fuzzTestHandler) SetWarnings(count uint16) { + th.mu.Lock() + defer th.mu.Unlock() + th.warnings = count +} + +func (th *fuzzTestHandler) NewConnection(c *Conn) { + th.mu.Lock() + defer th.mu.Unlock() + th.lastConn = c +} + +func (th *fuzzTestHandler) ConnectionClosed(_ *Conn) { +} + +func (th *fuzzTestHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { + + return nil +} + +func (th *fuzzTestHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { + return nil, nil +} + +func (th *fuzzTestHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *fuzzTestHandler) ComResetConnection(c *Conn) { + +} + +func (th *fuzzTestHandler) WarningCount(c *Conn) uint16 { + th.mu.Lock() + defer th.mu.Unlock() + return th.warnings +} + +func FuzzTLSServer(data []byte) int { + th := &fuzzTestHandler{} + + authServer := NewAuthServerStatic("", "", 0) + authServer.entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + }} + defer authServer.close() + l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) + if err != nil { + return -1 + } + defer l.Close() + host, err := os.Hostname() + if err != nil { + return -1 + } + port := l.Addr().(*net.TCPAddr).Port + root, err := ioutil.TempDir("", "TestTLSServer") + if err != nil { + return -1 + } + defer os.RemoveAll(root) + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", host) + tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem"), + "") + if err != nil { + return -1 + } + l.TLSConfig.Store(serverConfig) + go l.Accept() + + connCountByTLSVer.ResetAll() + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + // SSL flags. + Flags: CapabilityClientSSL, + SslCa: path.Join(root, "ca-cert.pem"), + SslCert: path.Join(root, "client-cert.pem"), + SslKey: path.Join(root, "client-key.pem"), + } + conn, err := Connect(context.Background(), params) + if err != nil { + return -1 + } + _, err = conn.ExecuteFetch(string(data), 1000, true) + if err != nil { + return 0 + } + return 1 +}