From 8b4711353793b5410c8dee9f44377bb11f9aa728 Mon Sep 17 00:00:00 2001 From: michael2008 Date: Fri, 21 Dec 2018 10:24:29 +0800 Subject: [PATCH] Add support for MySQL 8.0 and support for TLS/SSL for both Server and Client (#312) --- Gopkg.lock | 16 +- Gopkg.toml | 8 +- README.md | 33 +- client/auth.go | 176 +++++-- client/client_test.go | 67 ++- client/conn.go | 21 +- client/resp.go | 120 ++++- client/tls.go | 28 ++ docker/docker-compose.yaml | 80 +++ docker/resources/ca.key | 27 + docker/resources/ca.pem | 22 + docker/resources/client-cert.pem | 19 + docker/resources/client-key.pem | 27 + docker/resources/server-cert.pem | 19 + docker/resources/server-key.pem | 27 + driver/dirver_test.go | 1 + mysql/const.go | 17 +- mysql/field.go | 12 +- mysql/resultset.go | 4 +- mysql/util.go | 103 +++- packet/conn.go | 131 +++-- replication/binlogsyncer.go | 2 +- replication/row_event.go | 2 +- server/auth.go | 235 +++++---- server/auth_switch_response.go | 133 +++++ server/caching_sha2_cache_test.go | 233 +++++++++ server/command.go | 4 +- server/conn.go | 72 ++- server/credential_provider.go | 45 ++ server/example/server_example.go | 51 ++ server/handshake_resp.go | 190 +++++++ server/initial_handshake.go | 57 +++ server/resp.go | 53 ++ server/server_conf.go | 103 ++++ server/server_test.go | 252 ++++++---- server/ssl.go | 133 +++++ server/stmt.go | 10 +- test_util/test_keys/keys.go | 85 ++++ vendor/github.com/BurntSushi/toml/COPYING | 14 + .../toml/cmd/toml-test-decoder/COPYING | 14 + .../toml/cmd/toml-test-encoder/COPYING | 14 + .../BurntSushi/toml/cmd/tomlv/COPYING | 14 + vendor/github.com/go-sql-driver/mysql/AUTHORS | 90 ++++ .../go-sql-driver/mysql/appengine.go | 2 +- vendor/github.com/go-sql-driver/mysql/auth.go | 420 ++++++++++++++++ .../github.com/go-sql-driver/mysql/buffer.go | 12 +- .../go-sql-driver/mysql/collations.go | 1 + .../go-sql-driver/mysql/connection.go | 341 +++++++++++-- .../github.com/go-sql-driver/mysql/const.go | 25 +- .../github.com/go-sql-driver/mysql/driver.go | 84 ++-- vendor/github.com/go-sql-driver/mysql/dsn.go | 159 ++++-- .../github.com/go-sql-driver/mysql/errors.go | 79 +-- .../github.com/go-sql-driver/mysql/fields.go | 194 ++++++++ .../github.com/go-sql-driver/mysql/infile.go | 6 +- .../github.com/go-sql-driver/mysql/packets.go | 401 +++++++-------- vendor/github.com/go-sql-driver/mysql/rows.go | 174 +++++-- .../go-sql-driver/mysql/statement.go | 120 +++-- .../go-sql-driver/mysql/transaction.go | 4 +- .../github.com/go-sql-driver/mysql/utils.go | 463 +++++++++--------- .../shopspring/decimal/decimal-go.go | 414 ++++++++++++++++ .../github.com/shopspring/decimal/decimal.go | 341 ++++++++++++- .../github.com/shopspring/decimal/rounding.go | 118 +++++ vendor/google.golang.org/appengine/LICENSE | 202 ++++++++ .../appengine/cloudsql/cloudsql.go | 62 +++ .../appengine/cloudsql/cloudsql_classic.go | 17 + .../appengine/cloudsql/cloudsql_vm.go | 16 + 66 files changed, 5330 insertions(+), 1089 deletions(-) create mode 100644 client/tls.go create mode 100644 docker/docker-compose.yaml create mode 100644 docker/resources/ca.key create mode 100644 docker/resources/ca.pem create mode 100644 docker/resources/client-cert.pem create mode 100644 docker/resources/client-key.pem create mode 100644 docker/resources/server-cert.pem create mode 100644 docker/resources/server-key.pem create mode 100644 server/auth_switch_response.go create mode 100644 server/caching_sha2_cache_test.go create mode 100644 server/credential_provider.go create mode 100644 server/example/server_example.go create mode 100644 server/handshake_resp.go create mode 100644 server/initial_handshake.go create mode 100644 server/server_conf.go create mode 100644 server/ssl.go create mode 100644 test_util/test_keys/keys.go create mode 100644 vendor/github.com/BurntSushi/toml/COPYING create mode 100644 vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING create mode 100644 vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING create mode 100644 vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING create mode 100644 vendor/github.com/go-sql-driver/mysql/AUTHORS create mode 100644 vendor/github.com/go-sql-driver/mysql/auth.go create mode 100644 vendor/github.com/go-sql-driver/mysql/fields.go create mode 100644 vendor/github.com/shopspring/decimal/decimal-go.go create mode 100644 vendor/github.com/shopspring/decimal/rounding.go create mode 100644 vendor/google.golang.org/appengine/LICENSE create mode 100644 vendor/google.golang.org/appengine/cloudsql/cloudsql.go create mode 100644 vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go create mode 100644 vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go diff --git a/Gopkg.lock b/Gopkg.lock index 8fe7aa562..ae65b1d28 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -8,10 +8,10 @@ version = "v0.3.0" [[projects]] + branch = "master" name = "github.com/go-sql-driver/mysql" packages = ["."] - revision = "a0583e0143b1624142adab07e0e97fe106d99561" - version = "v1.3.0" + revision = "99ff426eb706cffe92ff3d058e168b278cabf7c7" [[projects]] branch = "master" @@ -43,8 +43,8 @@ [[projects]] name = "github.com/shopspring/decimal" packages = ["."] - revision = "69b3a8ad1f5f2c8bd855cb6506d18593064a346b" - version = "1.0.1" + revision = "cd690d0c9e2447b1ef2a129a6b7b49077da89b8e" + version = "1.1.0" [[projects]] branch = "master" @@ -64,9 +64,15 @@ ] revision = "a4d157e46fa3e08b7e7ff329af341fa3ff86c02c" +[[projects]] + name = "google.golang.org/appengine" + packages = ["cloudsql"] + revision = "b1f26356af11148e710935ed1ac8a7f5702c7612" + version = "v1.1.0" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "3a6f74dc90e82303c48131160fe66213a85d53581b2a0e0c7b36a8df0463fdd0" + inputs-digest = "a1f9939938a58551bbb3f19411c9d1386995d36296de6f6fb5d858f5923db85e" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index aecf2cacd..71df4b331 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -27,11 +27,11 @@ [[constraint]] name = "github.com/BurntSushi/toml" - version = "0.3.0" + version = "v0.3.0" [[constraint]] name = "github.com/go-sql-driver/mysql" - version = "1.3.0" + branch = "master" [[constraint]] branch = "master" @@ -39,11 +39,11 @@ [[constraint]] name = "github.com/satori/go.uuid" - version = "1.1.0" + version = "v1.2.0" [[constraint]] name = "github.com/shopspring/decimal" - version = "1.0.0" + version = "v1.1.0" [[constraint]] branch = "master" diff --git a/README.md b/README.md index 1d31dbc4d..0b958c746 100644 --- a/README.md +++ b/README.md @@ -138,9 +138,16 @@ import ( "github.com/siddontang/go-mysql/client" ) -// Connect MySQL at 127.0.0.1:3306, with user root, an empty passowrd and database test +// Connect MySQL at 127.0.0.1:3306, with user root, an empty password and database test conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test") +// Or to use SSL/TLS connection if MySQL server supports TLS +//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.UseSSL(true)}) + +// or to set your own client-side certificates for identity verification for security +//tlsConfig := NewClientTLSConfig(caPem, certPem, keyPem, false, "your-server-name") +//conn, _ := client.Connect("127.0.0.1:3306", "root", "", "test", func(c *Conn) {c.SetTLSConfig(tlsConfig)}) + conn.Ping() // Insert @@ -157,10 +164,17 @@ v, _ := r.GetInt(0, 0) v, _ = r.GetIntByName(0, "id") ``` +Tested MySQL versions for the client include: +- 5.5.x +- 5.6.x +- 5.7.x +- 8.0.x + ## Server Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client. -You can use it to build your own MySQL proxy. +You can use it to build your own MySQL proxy. The server connection is compatible with MySQL 5.5, 5.6, 5.7, and 8.0 versions, +so that most MySQL clients should be able to connect to the Server without modifications. ### Example @@ -174,14 +188,14 @@ l, _ := net.Listen("tcp", "127.0.0.1:4000") c, _ := l.Accept() -// Create a connection with user root and an empty passowrd -// We only an empty handler to handle command too +// Create a connection with user root and an empty password. +// You can use your own handler to handle command here. conn, _ := server.NewConn(c, "root", "", server.EmptyHandler{}) for { conn.HandleCommand() } -``` +``` Another shell @@ -190,6 +204,15 @@ mysql -h127.0.0.1 -P4000 -uroot -p //Becuase empty handler does nothing, so here the MySQL client can only connect the proxy server. :-) ``` +> ```NewConn()``` will use default server configurations: +> 1. automatically generate default server certificates and enable TLS/SSL support. +> 2. support three mainstream authentication methods **'mysql_native_password'**, **'caching_sha2_password'**, and **'sha256_password'** +> and use **'mysql_native_password'** as default. +> 3. use an in-memory user credential provider to store user and password. +> +> To customize server configurations, use ```NewServer()``` and create connection via ```NewCustomizedConn()```. + + ## Failover Failover supports to promote a new master and let other slaves replicate from it automatically when the old master was down. diff --git a/client/auth.go b/client/auth.go index 85b688c28..5ba9c9f4e 100644 --- a/client/auth.go +++ b/client/auth.go @@ -4,12 +4,29 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "fmt" "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/packet" ) +const defaultAuthPluginName = AUTH_NATIVE_PASSWORD + +// defines the supported auth plugins +var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD} + +// helper function to determine what auth methods are allowed by this client +func authPluginAllowed(pluginName string) bool { + for _, p := range supportedAuthPlugins { + if pluginName == p { + return true + } + } + return false +} + +// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake func (c *Conn) readInitialHandshake() error { data, err := c.ReadPacket() if err != nil { @@ -24,39 +41,44 @@ func (c *Conn) readInitialHandshake() error { return errors.Errorf("invalid protocol version %d, must >= 10", data[0]) } - //skip mysql version - //mysql version end with 0x00 + // skip mysql version + // mysql version end with 0x00 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 - //connection id length is 4 + // connection id length is 4 c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 c.salt = []byte{} c.salt = append(c.salt, data[pos:pos+8]...) - //skip filter + // skip filter pos += 8 + 1 - //capability lower 2 bytes + // capability lower 2 bytes c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) - + // check protocol + if c.capability&CLIENT_PROTOCOL_41 == 0 { + return errors.New("the MySQL server can not support protocol 41 and above required by the client") + } + if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil { + return errors.New("the MySQL Server does not support TLS required by the client") + } pos += 2 if len(data) > pos { - //skip server charset + // skip server charset //c.charset = data[pos] pos += 1 c.status = binary.LittleEndian.Uint16(data[pos : pos+2]) pos += 2 - + // capability flags (upper 2 bytes) c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability - pos += 2 - //skip auth data len or [00] - //skip reserved (all [00]) + // skip auth data len or [00] + // skip reserved (all [00]) pos += 10 + 1 // The documentation is ambiguous about the length. @@ -64,78 +86,131 @@ func (c *Conn) readInitialHandshake() error { // mysql-proxy also use 12 // which is not documented but seems to work. c.salt = append(c.salt, data[pos:pos+12]...) + pos += 13 + // auth plugin + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + c.authPluginName = string(data[pos : pos+end]) + } else { + c.authPluginName = string(data[pos:]) + } + } + + // if server gives no default auth plugin name, use a client default + if c.authPluginName == "" { + c.authPluginName = defaultAuthPluginName } return nil } +// generate auth response data according to auth plugin +// +// NOTE: the returned boolean value indicates whether to add a \NUL to the end of data. +// it is quite tricky because MySQl server expects different formats of responses in different auth situations. +// here the \NUL needs to be added when sending back the empty password or cleartext password in 'sha256_password' +// authentication. +func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { + // password hashing + switch c.authPluginName { + case AUTH_NATIVE_PASSWORD: + return CalcPassword(authData[:20], []byte(c.password)), false, nil + case AUTH_CACHING_SHA2_PASSWORD: + return CalcCachingSha2Password(authData, c.password), false, nil + case AUTH_SHA256_PASSWORD: + if len(c.password) == 0 { + return nil, true, nil + } + if c.tlsConfig != nil || c.proto == "unix" { + // write cleartext auth packet + // see: https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html + return []byte(c.password), true, nil + } else { + // request public key from server + // see: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html + return []byte{1}, false, nil + } + default: + // not reachable + return nil, false, fmt.Errorf("auth plugin '%s' is not supported", c.authPluginName) + } +} + +// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (c *Conn) writeAuthHandshake() error { + if !authPluginAllowed(c.authPluginName) { + return fmt.Errorf("unknow auth plugin name '%s'", c.authPluginName) + } // Adjust client capability flags based on server support capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | - CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG + CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH | c.capability&CLIENT_LONG_FLAG // To enable TLS / SSL - if c.TLSConfig != nil { - capability |= CLIENT_PLUGIN_AUTH + if c.tlsConfig != nil { capability |= CLIENT_SSL } - capability &= c.capability + auth, addNull, err := c.genAuthResponse(c.salt) + if err != nil { + return err + } + + // encode length of the auth plugin data + // here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte + // see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer + var authRespLEIBuf [9]byte + authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth))) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + } //packet length - //capbility 4 + //capability 4 //max-packet size 4 //charset 1 //reserved all[0] 23 - length := 4 + 4 + 1 + 23 - //username - length += len(c.user) + 1 - - //we only support secure connection - auth := CalcPassword(c.salt, []byte(c.password)) - - length += 1 + len(auth) - + //auth + //mysql_native_password + null-terminated + length := 4 + 4 + 1 + 23 + len(c.user) + 1 + len(authRespLEI) + len(auth) + 21 + 1 + if addNull { + length++ + } + // db name if len(c.db) > 0 { capability |= CLIENT_CONNECT_WITH_DB - length += len(c.db) + 1 } - // mysql_native_password + null-terminated - length += 21 + 1 - - c.capability = capability - data := make([]byte, length+4) - //capability [32 bit] + // capability [32 bit] data[4] = byte(capability) data[5] = byte(capability >> 8) data[6] = byte(capability >> 16) data[7] = byte(capability >> 24) - //MaxPacketSize [32 bit] (none) - //data[8] = 0x00 - //data[9] = 0x00 - //data[10] = 0x00 - //data[11] = 0x00 + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 - //Charset [1 byte] - //use default collation id 33 here, is utf-8 + // Charset [1 byte] + // use default collation id 33 here, is utf-8 data[12] = byte(DEFAULT_COLLATION_ID) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest - if c.TLSConfig != nil { + if c.tlsConfig != nil { // Send TLS / SSL request packet if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS - tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig) + tlsConn := tls.Client(c.Conn.Conn, c.tlsConfig) if err := tlsConn.Handshake(); err != nil { return err } @@ -145,10 +220,13 @@ func (c *Conn) writeAuthHandshake() error { c.Sequence = currentSequence } - //Filler [23 bytes] (all 0x00) - pos := 13 + 23 + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } - //User [null terminated string] + // User [null terminated string] if len(c.user) > 0 { pos += copy(data[pos:], c.user) } @@ -156,8 +234,12 @@ func (c *Conn) writeAuthHandshake() error { pos++ // auth [length encoded integer] - data[pos] = byte(len(auth)) - pos += 1 + copy(data[pos+1:], auth) + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], auth) + if addNull { + data[pos] = 0x00 + pos++ + } // db [null terminated string] if len(c.db) > 0 { @@ -167,7 +249,7 @@ func (c *Conn) writeAuthHandshake() error { } // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], c.authPluginName) data[pos] = 0x00 return c.WritePacket(data) diff --git a/client/client_test.go b/client/client_test.go index 06713c975..04bfdb2d4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,41 +1,56 @@ package client import ( - "crypto/tls" "flag" "fmt" "strings" "testing" + "github.com/juju/errors" . "github.com/pingcap/check" + "github.com/siddontang/go-mysql/test_util/test_keys" "github.com/siddontang/go-mysql/mysql" ) var testHost = flag.String("host", "127.0.0.1", "MySQL server host") -var testPort = flag.Int("port", 3306, "MySQL server port") +// We cover the whole range of MySQL server versions using docker-compose to bind them to different ports for testing. +// MySQL is constantly updating auth plugin to make it secure: +// starting from MySQL 8.0.4, a new auth plugin is introduced, causing plain password auth to fail with error: +// ERROR 1251 (08004): Client does not support authentication protocol requested by server; consider upgrading MySQL client +// Hint: use docker-compose to start corresponding MySQL docker containers and add the their ports here +var testPort = flag.String("port", "3306", "MySQL server port") // choose one or more form 5561,5641,3306,5722,8003,8012,8013, e.g. '3306,5722,8003' var testUser = flag.String("user", "root", "MySQL user") var testPassword = flag.String("pass", "", "MySQL password") var testDB = flag.String("db", "test", "MySQL test database") func Test(t *testing.T) { + segs := strings.Split(*testPort, ",") + for _, seg := range segs { + Suite(&clientTestSuite{port: seg}) + } TestingT(t) } type clientTestSuite struct { - c *Conn + c *Conn + port string } -var _ = Suite(&clientTestSuite{}) - func (s *clientTestSuite) SetUpSuite(c *C) { var err error - addr := fmt.Sprintf("%s:%d", *testHost, *testPort) - s.c, err = Connect(addr, *testUser, *testPassword, *testDB) + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + s.c, err = Connect(addr, *testUser, *testPassword, "") if err != nil { c.Fatal(err) } + _, err = s.c.Execute("CREATE DATABASE IF NOT EXISTS " + *testDB) + c.Assert(err, IsNil) + + _, err = s.c.Execute("USE " + *testDB) + c.Assert(err, IsNil) + s.testConn_CreateTable(c) s.testStmt_CreateTable(c) } @@ -78,12 +93,15 @@ func (s *clientTestSuite) TestConn_Ping(c *C) { c.Assert(err, IsNil) } -func (s *clientTestSuite) TestConn_TLS(c *C) { +// NOTE for MySQL 5.5 and 5.6, server side has to config SSL to pass the TLS test, otherwise, it will throw error that +// MySQL server does not support TLS required by the client. However, for MySQL 5.7 and above, auto generated certificates +// are used by default so that manual config is no longer necessary. +func (s *clientTestSuite) TestConn_TLS_Verify(c *C) { // Verify that the provided tls.Config is used when attempting to connect to mysql. // An empty tls.Config will result in a connection error. - addr := fmt.Sprintf("%s:%d", *testHost, *testPort) + addr := fmt.Sprintf("%s:%s", *testHost, s.port) _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { - c.TLSConfig = &tls.Config{} + c.UseSSL(false) }) if err == nil { c.Fatal("expected error") @@ -95,6 +113,33 @@ func (s *clientTestSuite) TestConn_TLS(c *C) { } } +func (s *clientTestSuite) TestConn_TLS_Skip_Verify(c *C) { + // An empty tls.Config will result in a connection error but we can configure to skip it. + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + c.UseSSL(true) + }) + c.Assert(err, Equals, nil) +} + +func (s *clientTestSuite) TestConn_TLS_Certificate(c *C) { + // This test uses the TLS suite in 'go-mysql/docker/resources'. The certificates are not valid for any names. + // And if server uses auto-generated certificates, it will be an error like: + // "x509: certificate is valid for MySQL_Server_8.0.12_Auto_Generated_Server_Certificate, not not-a-valid-name" + tlsConfig := NewClientTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, false, "not-a-valid-name") + addr := fmt.Sprintf("%s:%s", *testHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, *testDB, func(c *Conn) { + c.SetTLSConfig(tlsConfig) + }) + if err == nil { + c.Fatal("expected error") + } + if !strings.Contains(errors.Details(err), "certificate is not valid for any names") && + !strings.Contains(errors.Details(err), "certificate is valid for") { + c.Fatalf("expected errors for server name verification, but got unknown error: %s", errors.Details(err)) + } +} + func (s *clientTestSuite) TestConn_Insert(c *C) { str := `insert into mixer_test_conn (id, str, f, e) values(1, "a", 3.14, "test1")` @@ -349,4 +394,4 @@ func (s *clientTestSuite) TestStmt_Trans(c *C) { str, _ = r.GetString(0, 0) c.Assert(str, Equals, `abc`) -} +} \ No newline at end of file diff --git a/client/conn.go b/client/conn.go index 54ee3f050..b015b433e 100644 --- a/client/conn.go +++ b/client/conn.go @@ -18,7 +18,8 @@ type Conn struct { user string password string db string - TLSConfig *tls.Config + tlsConfig *tls.Config + proto string capability uint32 @@ -26,7 +27,8 @@ type Conn struct { charset string - salt []byte + salt []byte + authPluginName string connectionID uint32 } @@ -56,6 +58,7 @@ func Connect(addr string, user string, password string, dbName string, options . c.user = user c.password = password c.db = dbName + c.proto = proto //use default charset here, utf-8 c.charset = DEFAULT_CHARSET @@ -85,7 +88,7 @@ func (c *Conn) handshake() error { return errors.Trace(err) } - if _, err := c.readOK(); err != nil { + if err := c.handleAuthResult(); err != nil { c.Close() return errors.Trace(err) } @@ -109,6 +112,18 @@ func (c *Conn) Ping() error { return nil } +// use default SSL +// pass to options when connect +func (c *Conn) UseSSL(insecureSkipVerify bool) { + c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify} +} + +// use user-specified TLS config +// pass to options when connect +func (c *Conn) SetTLSConfig(config *tls.Config) { + c.tlsConfig = config +} + func (c *Conn) UseDB(dbName string) error { if c.db == dbName { return nil diff --git a/client/resp.go b/client/resp.go index 4e5f8559d..71aa1bcd4 100644 --- a/client/resp.go +++ b/client/resp.go @@ -1,8 +1,14 @@ package client +import "C" import ( "encoding/binary" + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go/hack" @@ -32,7 +38,7 @@ func (c *Conn) isEOFPacket(data []byte) bool { func (c *Conn) handleOKPacket(data []byte) (*Result, error) { var n int - var pos int = 1 + var pos = 1 r := new(Result) @@ -64,7 +70,7 @@ func (c *Conn) handleOKPacket(data []byte) (*Result, error) { func (c *Conn) handleErrorPacket(data []byte) error { e := new(MyError) - var pos int = 1 + var pos = 1 e.Code = binary.LittleEndian.Uint16(data[pos:]) pos += 2 @@ -81,6 +87,116 @@ func (c *Conn) handleErrorPacket(data []byte) error { return e } +func (c *Conn) handleAuthResult() error { + data, switchToPlugin, err := c.readAuthResult() + if err != nil { + return err + } + // handle auth switch, only support 'sha256_password', and 'caching_sha2_password' + if switchToPlugin != "" { + //fmt.Printf("now switching auth plugin to '%s'\n", switchToPlugin) + if data == nil { + data = c.salt + } else { + copy(c.salt, data) + } + c.authPluginName = switchToPlugin + auth, addNull, err := c.genAuthResponse(data) + if err = c.WriteAuthSwitchPacket(auth, addNull); err != nil { + return err + } + + // Read Result Packet + data, switchToPlugin, err = c.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if switchToPlugin != "" { + return errors.Errorf("can not switch auth plugin more than once") + } + } + + // handle caching_sha2_password + if c.authPluginName == AUTH_CACHING_SHA2_PASSWORD { + if data == nil { + return nil // auth already succeeded + } + if data[0] == CACHE_SHA2_FAST_AUTH { + if _, err = c.readOK(); err == nil { + return nil // auth successful + } + } else if data[0] == CACHE_SHA2_FULL_AUTH { + // need full authentication + if c.tlsConfig != nil || c.proto == "unix" { + if err = c.WriteClearAuthPacket(c.password); err != nil { + return err + } + } else { + if err = c.WritePublicKeyAuthPacket(c.password, c.salt); err != nil { + return err + } + } + } else { + errors.Errorf("invalid packet") + } + } else if c.authPluginName == AUTH_SHA256_PASSWORD { + if len(data) == 0 { + return nil // auth already succeeded + } + block, _ := pem.Decode(data) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + // send encrypted password + err = c.WriteEncryptedPassword(c.password, c.salt, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + _, err = c.readOK() + return err + } + return nil +} + +func (c *Conn) readAuthResult() ([]byte, string, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, "", err + } + + // see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + // packet indicator + switch data[0] { + + case OK_HEADER: + _, err := c.handleOKPacket(data) + return nil, "", err + + case MORE_DATE_HEADER: + return data[1:], "", err + + case EOF_HEADER: + // server wants to switch auth + if len(data) < 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, AUTH_MYSQL_OLD_PASSWORD, nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", errors.New("invalid packet") + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", c.handleErrorPacket(data) + } +} + func (c *Conn) readOK() (*Result, error) { data, err := c.ReadPacket() if err != nil { diff --git a/client/tls.go b/client/tls.go new file mode 100644 index 000000000..3772a50f8 --- /dev/null +++ b/client/tls.go @@ -0,0 +1,28 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" +) + +// generate TLS config for client side +// if insecureSkipVerify is set to true, serverName will not be validated +func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool, serverName string) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPem) { + panic("failed to add ca PEM") + } + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(err) + } + + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + InsecureSkipVerify: insecureSkipVerify, + ServerName: serverName, + } + return config +} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 000000000..151786ef8 --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,80 @@ +version: '3' +services: + + mysql-5.5.61: + image: "mysql:5.5.61" + container_name: "mysql-server-5.5.61" + ports: + - "5561:3306" + command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem + volumes: + - ./resources/ca.pem:/usr/local/mysql/ca.pem + - ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem + - ./resources/server-key.pem:/usr/local/mysql/server-key.pem + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-5.6.41: + image: "mysql:5.6.41" + container_name: "mysql-server-5.6.41" + ports: + - "5641:3306" + command: --ssl=TRUE --ssl-ca=/usr/local/mysql/ca.pem --ssl-cert=/usr/local/mysql/server-cert.pem --ssl-key=/usr/local/mysql/server-key.pem + volumes: + - ./resources/ca.pem:/usr/local/mysql/ca.pem + - ./resources/server-cert.pem:/usr/local/mysql/server-cert.pem + - ./resources/server-key.pem:/usr/local/mysql/server-key.pem + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-default: + image: "mysql:5.7.22" + container_name: "mysql-server-default" + ports: + - "3306:3306" + command: ["mysqld", "--log-bin=mysql-bin", "--server-id=1"] + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-5.7.22: + image: "mysql:5.7.22" + container_name: "mysql-server-5.7.22" + ports: + - "5722:3306" + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.3: + image: "mysql:8.0.3" + container_name: "mysql-server-8.0.3" + ports: + - "8003:3306" + environment: + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.12: + image: "mysql:8.0.12" + container_name: "mysql-server-8.0.12" + ports: + - "8012:3306" + environment: + #- MYSQL_ROOT_PASSWORD=abc123 + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + + mysql-8.0.12-sha256: + image: "mysql:8.0.12" + container_name: "mysql-server-8.0.12-sha256" + ports: + - "8013:3306" + entrypoint: ['/entrypoint.sh', '--default-authentication-plugin=sha256_password'] + environment: + #- MYSQL_ROOT_PASSWORD=abc123 + - MYSQL_ALLOW_EMPTY_PASSWORD=true + - bind-address=0.0.0.0 + diff --git a/docker/resources/ca.key b/docker/resources/ca.key new file mode 100644 index 000000000..8344ed23d --- /dev/null +++ b/docker/resources/ca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QL +suT1fGf5zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eo +xrS64GSbPVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKv +lpEAtRAe4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7 +IeXUc456yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioO +NaePQKFR/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABAoIBAC2U0jponRiGmgIl +gohw6+D+6pNeaKAAUkwYbKXJZ3noWLFr4T3GDTg9WDqvcvJg+rT9NvZxdCW3tDc5 +CVBcwO1g9PVcUEaRqcme3EhrxKdQQ76QmjUGeQf1ktd+YnmiZ1kOnGLtZ9/gsYpQ +06iGSIOX3+xA4BQOhEAPCOShMjYv+pWvWGhZCSmeulKulNVPBbG2H1I9EoT5Wd+Q +8LUfgZOuUXrtcsuvEf2XeacCo0pUbjx8ErhDHP6aPasFAXq15Bm8DnsUOrrsjcLy +sPy/mHwpd6kTw+O3EzjTdaYSFRoDSpfpIS5Bk+yicdxOmTwp1pzDu6HyYnuOnc9Q +JQ8HvlECgYEA2z+1HKVz5k7NYyRihW4l30vAcAGcgG1RObB6DmLbGu4MPvMymgLO +1QhYjlCcKfRHhVS2864op3Oba2fIgCc2am0DIQQ6kZ23ick78aj9G2ZXYpdpIPLu +Kl1AZHj6XDrOPVqidwcE6iYHLLWp9x4Atgw5d44XmhQ0kwrqAfccOX8CgYEAzxnl +7Uu+v5WI3hBVJpxS6eoS1TdztVIJaumyE43pBoHEuJrp4MRf0Lu2DiDpH8R3/RoE +o+ykn6xzphYwUopYaCWzYTKoXvxCvmLkDjHcpdzLtwWbKG+MJih2nTADEDI7sK4e +a3IU8miK6FeqkQHfs/5dlQa8q31yxiukw0qQEP0CgYAtLg6jTZD5l6mJUZkfx9f0 +EMciDaLzcBN54Nz2E/b0sLNDUZhO1l9K1QJyqTfVCWqnlhJxWqU0BIW1d1iA2BPF +kJtBdX6gPTDyKs64eMtXlxpQzcSzLnxXrIm1apyk3tVbHU83WfHwUk/OLc1NiBg7 +a394HIbOkHVZC7m3F/Xv/wKBgQDHrM2du8D+kJs0l4SxxFjAxPlBb8R01tLTrNwP +tGwu5OEZp+rE1jEXXFRMTPjXsyKI+hPtRJT4ilm6kXwnqNFSIL9RgHkLk6Z6T3hY +I0T8+ePD43jURLBYffzW0tqxO+2HDGmx6H0/twHuv89pHehkb2Qk8ijoIvyNCrlB +vVsntQKBgCK04nbb+G45D6TKCcZ6XKT/+qneJQE5cfvHl5EqrfjSmlnEUpJjJfyc +6Q1PtXtWOtOScU93f1JKL7+JBbWDn9uBlboM8BSkAVVd/2vyg88RuEtIru1syxcW +d1rMxqaMRJuhuqaS33CoPUpn30b4zVrPhQJ2+TwDAol4qIGHaie8 +-----END RSA PRIVATE KEY----- diff --git a/docker/resources/ca.pem b/docker/resources/ca.pem new file mode 100644 index 000000000..e251bd64d --- /dev/null +++ b/docker/resources/ca.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5 +zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb +PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe +4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456 +yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR +/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD +00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh +jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV +BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq +qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g +uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd +cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz +pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC +KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo= +-----END CERTIFICATE----- diff --git a/docker/resources/client-cert.pem b/docker/resources/client-cert.pem new file mode 100644 index 000000000..e478e7863 --- /dev/null +++ b/docker/resources/client-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuTANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDY0N1oXDTI4MDgxNjA4NDY0N1owRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMmivNyk3Rc1ZvLPhb3WPNkf9f2G4g9nMc0+eMrR1IKJ1U1A98ojeIBT+pfk1bSq +Ol0UDm66Vd3YQ+4HpyYHaYV6mwoTEulL9Quk8RLa7TRwQu3PLi3o567RhVIrx8Z3 +umuWb9UUzJfSFH04Uy9+By4CJCqIQXU4BocLIKHhIkNjmAQ9fWO1hZ8zmPHSEfvu +Wqa/DYKGvF0MJr4Lnkm/sKUd+O94p9suvwM6OGIDibACiKRF2H+JbgQLbA58zkLv +DHtXOqsCL7HxiONX8VDrQjN/66Nh9omk/Bx2Ec8IqappHvWf768HSH79x/znaial +VEV+6K0gP+voJHfnA10laWMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAPD+Fn1qj +HN62GD3eIgx6wJxYuemhdbgmEwrZZf4V70lS6e9Iloif0nBiISDxJUpXVWNRCN3Z +3QVC++F7deDmWL/3dSpXRvWsapzbCUhVQ2iBcnZ7QCOdvAqYR1ecZx70zvXCwBcd +6XKmRtdeNV6B211KRFmTYtVyPq4rcWrkTPGwPBncJI1eQQmyFv2T9SwVVp96Nbrq +sf7zrJGmuVCdXGPRi/ALVHtJCz6oPoft3I707eMe+ijnFqwGbmMD4fMD6Ync/hEz +PyR5FMZkXSXHS0gkA5pfwW7wJ2WSWDhI6JMS1gbatY7QzgHbKoQpxBPUXlnzzj2h +7O9cgFTh/XOZXQ== +-----END CERTIFICATE----- diff --git a/docker/resources/client-key.pem b/docker/resources/client-key.pem new file mode 100644 index 000000000..996a97b2f --- /dev/null +++ b/docker/resources/client-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAyaK83KTdFzVm8s+FvdY82R/1/YbiD2cxzT54ytHUgonVTUD3 +yiN4gFP6l+TVtKo6XRQObrpV3dhD7genJgdphXqbChMS6Uv1C6TxEtrtNHBC7c8u +LejnrtGFUivHxne6a5Zv1RTMl9IUfThTL34HLgIkKohBdTgGhwsgoeEiQ2OYBD19 +Y7WFnzOY8dIR++5apr8Ngoa8XQwmvgueSb+wpR3473in2y6/Azo4YgOJsAKIpEXY +f4luBAtsDnzOQu8Me1c6qwIvsfGI41fxUOtCM3/ro2H2iaT8HHYRzwipqmke9Z/v +rwdIfv3H/OdqJqVURX7orSA/6+gkd+cDXSVpYwIDAQABAoIBAAGLY5L1GFRzLkSx +3j5kA7dODV5RyC2CBtmhnt8+2DffwmiDFOLRfrzM5+B9+j0WCLhpzOqANuQqIesS +1+7so5xIIiPjnYN393qNWuNgFe0O5xRXP+1OGWg3ZqQIfdFBXYYxcs3ZCPAoxctn +wQteFcP+dDR3MrkpIrOqHCfhR5foieOMP+9k5kCjk+aZqhEmFyko+X+xVO/32xs+ ++3qXhUrHt3Op5on30QMOFguniQlYwLJkd9qVjGuGMIrVPxoUz0rya4SKrGKgkAr8 +mvQe2+sZo7cc6zC2ceaGMJU7z1RalTrCObbg5mynlu+Vf0E/YiES0abkQhSbcSB9 +mAkJC7ECgYEA/H1NDEiO164yYK9ji4HM/8CmHegWS4qsgrzAs8lU0yAcgdg9e19A +mNi8yssfIBCw62RRE4UGWS5F82myhmvq/mXbf8eCJ2CMgdCHQh1rT7WFD/Uc5Pe/ +8Lv2jNMQ61POguPyq6D0qcf8iigKIMHa1MIgAOmrgWrxulfbSUhm370CgYEAzHBu +J9p4dAqW32+Hrtv2XE0KUjH72TXr13WErosgeGTfsIW2exXByvLasxOJSY4Wb8oS +OLZ7bgp/EBchAc7my+nF8n5uOJxipWQUB5BoeB9aUJZ9AnWF4RDl94Jlm5PYBG/J +lRXrMtSTTIgmSw3Ft2A1vRMOQaHX89lNwOZL758CgYAXOT84/gOFexRPKFKzpkDA +1WtyHMLQN/UeIVZoMwCGWtHEb6tYCa7bYDQdQwmd3Wsoe5WpgfbPhR4SAYrWKl72 +/09tNWCXVp4V4qRORH52Wm/ew+Dgfpk8/0zyLwfDXXYFPAo6Fxfp9ecYng4wbSQ/ +pYtkChooUTniteoJl4s+0QKBgHbFEpoAqF3yEPi52L/TdmrlLwvVkhT86IkB8xVc +Kn8HS5VH+V3EpBN9x2SmAupCq/JCGRftnAOwAWWdqkVcqGTq6V8Z6HrnD8A6RhCm +6qpuvI94/iNBl4fLw25pyRH7cFITh68fTsb3DKQ3rNeJpsYEFPRFb9Ddb5JxOmTI +5nDNAoGBAM+SyOhUGU+0Uw2WJaGWzmEutjeMRr5Z+cZ8keC/ZJNdji/faaQoeOQR +OXI8O6RBTBwVNQMyDyttT8J8BkISwfAhSdPkjgPw9GZ1pGREl53uCFDIlX2nvtQM +ioNzG5WHB7Gd7eUUTA91kRF9MZJTHPqNiNGR0Udj/trGyGqJebni +-----END RSA PRIVATE KEY----- diff --git a/docker/resources/server-cert.pem b/docker/resources/server-cert.pem new file mode 100644 index 000000000..3cb3b9ca6 --- /dev/null +++ b/docker/resources/server-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4 +HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10 +PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0 +At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd +HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9 +IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7 +xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T +IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE +sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd +xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6 +XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff +99dAgk3gWwWZQA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/docker/resources/server-key.pem b/docker/resources/server-key.pem new file mode 100644 index 000000000..babaaaec2 --- /dev/null +++ b/docker/resources/server-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer +2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7 +nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/ +n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby +1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5 +aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9 +838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s +ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv +4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy +5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG +M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3 +lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy +KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK +UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7 +2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP +XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE +OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5 +oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1 +v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla +pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i +EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp +wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn +B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P +x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5 +S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM= +-----END RSA PRIVATE KEY----- diff --git a/driver/dirver_test.go b/driver/dirver_test.go index dc786ba0f..d43580f07 100644 --- a/driver/dirver_test.go +++ b/driver/dirver_test.go @@ -11,6 +11,7 @@ import ( // Use docker mysql to test, mysql is 3306 var testHost = flag.String("host", "127.0.0.1", "MySQL master host") +// possible choices for different MySQL versions are: 5561,5641,3306,5722,8003,8012 var testPort = flag.Int("port", 3306, "MySQL server port") var testUser = flag.String("user", "root", "MySQL user") var testPassword = flag.String("pass", "", "MySQL password") diff --git a/mysql/const.go b/mysql/const.go index a4862ea73..256d163e0 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -6,16 +6,22 @@ const ( TimeFormat string = "2006-01-02 15:04:05" ) -var ( - // maybe you can change for your specified name - ServerVersion string = "5.7.0" -) - const ( OK_HEADER byte = 0x00 + MORE_DATE_HEADER byte = 0x01 ERR_HEADER byte = 0xff EOF_HEADER byte = 0xfe LocalInFile_HEADER byte = 0xfb + + CACHE_SHA2_FAST_AUTH byte = 0x03 + CACHE_SHA2_FULL_AUTH byte = 0x04 +) + +const ( + AUTH_MYSQL_OLD_PASSWORD = "mysql_old_password" + AUTH_NATIVE_PASSWORD = "mysql_native_password" + AUTH_CACHING_SHA2_PASSWORD = "caching_sha2_password" + AUTH_SHA256_PASSWORD = "sha256_password" ) const ( @@ -151,7 +157,6 @@ const ( ) const ( - AUTH_NAME = "mysql_native_password" DEFAULT_CHARSET = "utf8" DEFAULT_COLLATION_ID uint8 = 33 DEFAULT_COLLATION_NAME string = "utf8_general_ci" diff --git a/mysql/field.go b/mysql/field.go index c26f6a292..891f00b15 100644 --- a/mysql/field.go +++ b/mysql/field.go @@ -31,42 +31,42 @@ func (p FieldData) Parse() (f *Field, err error) { var n int pos := 0 //skip catelog, always def - n, err = SkipLengthEnodedString(p) + n, err = SkipLengthEncodedString(p) if err != nil { return } pos += n //schema - f.Schema, _, n, err = LengthEnodedString(p[pos:]) + f.Schema, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //table - f.Table, _, n, err = LengthEnodedString(p[pos:]) + f.Table, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //org_table - f.OrgTable, _, n, err = LengthEnodedString(p[pos:]) + f.OrgTable, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //name - f.Name, _, n, err = LengthEnodedString(p[pos:]) + f.Name, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } pos += n //org_name - f.OrgName, _, n, err = LengthEnodedString(p[pos:]) + f.OrgName, _, n, err = LengthEncodedString(p[pos:]) if err != nil { return } diff --git a/mysql/resultset.go b/mysql/resultset.go index 080405087..b01e1a530 100644 --- a/mysql/resultset.go +++ b/mysql/resultset.go @@ -28,7 +28,7 @@ func (p RowData) ParseText(f []*Field) ([]interface{}, error) { var n int = 0 for i := range f { - v, isNull, n, err = LengthEnodedString(p[pos:]) + v, isNull, n, err = LengthEncodedString(p[pos:]) if err != nil { return nil, errors.Trace(err) } @@ -151,7 +151,7 @@ func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY: - v, isNull, n, err = LengthEnodedString(p[pos:]) + v, isNull, n, err = LengthEncodedString(p[pos:]) pos += n if err != nil { return nil, errors.Trace(err) diff --git a/mysql/util.go b/mysql/util.go index 7fe41fa21..757910eef 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -11,6 +11,8 @@ import ( "github.com/juju/errors" "github.com/siddontang/go/hack" + "crypto/sha256" + "crypto/rsa" ) func Pstack() string { @@ -48,6 +50,62 @@ func CalcPassword(scramble, password []byte) []byte { return scramble } +// Hash password using MySQL 8+ method (SHA256) +func CalcCachingSha2Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + + +func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1v := sha1.New() + return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil) +} + +// encodes a uint64 value and appends it to the given bytes slice +func AppendLengthEncodedInteger(b []byte, n uint64) []byte { + switch { + case n <= 250: + return append(b, byte(n)) + + case n <= 0xffff: + return append(b, 0xfc, byte(n), byte(n>>8)) + + case n <= 0xffffff: + return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) + } + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) +} + func RandomBuf(size int) ([]byte, error) { buf := make([]byte, size) @@ -84,39 +142,33 @@ func BFixedLengthInt(buf []byte) uint64 { } func LengthEncodedInt(b []byte) (num uint64, isNull bool, n int) { - switch b[0] { + if len(b) == 0 { + return 0, true, 1 + } + switch b[0] { // 251: NULL case 0xfb: - n = 1 - isNull = true - return + return 0, true, 1 - // 252: value of following 2 + // 252: value of following 2 case 0xfc: - num = uint64(b[1]) | uint64(b[2])<<8 - n = 3 - return + return uint64(b[1]) | uint64(b[2])<<8, false, 3 - // 253: value of following 3 + // 253: value of following 3 case 0xfd: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 - n = 4 - return + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 - // 254: value of following 8 + // 254: value of following 8 case 0xfe: - num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<56 - n = 9 - return + uint64(b[7])<<48 | uint64(b[8])<<56, + false, 9 } // 0-250: value of first byte - num = uint64(b[0]) - n = 1 - return + return uint64(b[0]), false, 1 } func PutLengthEncodedInt(n uint64) []byte { @@ -137,23 +189,26 @@ func PutLengthEncodedInt(n uint64) []byte { return nil } -func LengthEnodedString(b []byte) ([]byte, bool, int, error) { +// returns the string read as a bytes slice, whether the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func LengthEncodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := LengthEncodedInt(b) if num < 1 { - return nil, isNull, n, nil + return b[n:n], isNull, n, nil } n += int(num) // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } -func SkipLengthEnodedString(b []byte) (int, error) { +func SkipLengthEncodedString(b []byte) (int, error) { // Get length num, _, n := LengthEncodedInt(b) if num < 1 { diff --git a/packet/conn.go b/packet/conn.go index 3772e1a33..41b1bf1c7 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -1,11 +1,17 @@ package packet +import "C" import ( - "bufio" "bytes" "io" "net" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/pem" + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) @@ -15,7 +21,9 @@ import ( */ type Conn struct { net.Conn - br *bufio.Reader + + // we removed the buffer reader because it will cause the SSLRequest to block (tls connection handshake won't be + // able to read the "Client Hello" data since it has been buffered into the buffer reader) Sequence uint8 } @@ -23,7 +31,6 @@ type Conn struct { func NewConn(conn net.Conn) *Conn { c := new(Conn) - c.br = bufio.NewReaderSize(conn, 4096) c.Conn = conn return c @@ -37,55 +44,20 @@ func (c *Conn) ReadPacket() ([]byte, error) { } else { return buf.Bytes(), nil } - - // header := []byte{0, 0, 0, 0} - - // if _, err := io.ReadFull(c.br, header); err != nil { - // return nil, ErrBadConn - // } - - // length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - // if length < 1 { - // return nil, fmt.Errorf("invalid payload length %d", length) - // } - - // sequence := uint8(header[3]) - - // if sequence != c.Sequence { - // return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence) - // } - - // c.Sequence++ - - // data := make([]byte, length) - // if _, err := io.ReadFull(c.br, data); err != nil { - // return nil, ErrBadConn - // } else { - // if length < MaxPayloadLen { - // return data, nil - // } - - // var buf []byte - // buf, err = c.ReadPacket() - // if err != nil { - // return nil, ErrBadConn - // } else { - // return append(data, buf...), nil - // } - // } } func (c *Conn) ReadPacketTo(w io.Writer) error { header := []byte{0, 0, 0, 0} - if _, err := io.ReadFull(c.br, header); err != nil { + if _, err := io.ReadFull(c.Conn, header); err != nil { return ErrBadConn } length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - if length < 1 { - return errors.Errorf("invalid payload length %d", length) - } + // bug fixed: caching_sha2_password will send 0-length payload (the unscrambled password) when the password is empty + //if length < 1 { + // return errors.Errorf("invalid payload length %d", length) + //} sequence := uint8(header[3]) @@ -95,7 +67,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { c.Sequence++ - if n, err := io.CopyN(w, c.br, int64(length)); err != nil { + if n, err := io.CopyN(w, c.Conn, int64(length)); err != nil { return ErrBadConn } else if n != int64(length) { return ErrBadConn @@ -150,6 +122,77 @@ func (c *Conn) WritePacket(data []byte) error { } } +// Client clear text authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WriteClearAuthPacket(password string) error { + // Calculate the packet length and add a tailing 0 + pktLen := len(password) + 1 + data := make([]byte, 4 + pktLen) + + // Add the clear password [null terminated string] + copy(data[4:], password) + data[4+pktLen-1] = 0x00 + + return c.WritePacket(data) +} + +// Caching sha2 authentication. Public key request and send encrypted password +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WritePublicKeyAuthPacket(password string, cipher []byte) error { + // request public key + data := make([]byte, 4 + 1) + data[4] = 2 // cachingSha2PasswordRequestPublicKey + c.WritePacket(data) + + data, err := c.ReadPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(cipher) + plain[i] ^= cipher[j] + } + sha1v := sha1.New() + enc, _ := rsa.EncryptOAEP(sha1v, rand.Reader, pub.(*rsa.PublicKey), plain, nil) + data = make([]byte, 4 + len(enc)) + copy(data[4:], enc) + return c.WritePacket(data) +} + +func (c *Conn) WriteEncryptedPassword(password string, seed []byte, pub *rsa.PublicKey) error { + enc, err := EncryptPassword(password, seed, pub) + if err != nil { + return err + } + return c.WriteAuthSwitchPacket(enc, false) +} + +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (c *Conn) WriteAuthSwitchPacket(authData []byte, addNUL bool) error { + pktLen := 4 + len(authData) + if addNUL { + pktLen++ + } + data := make([]byte, pktLen) + + // Add the auth data [EOF] + copy(data[4:], authData) + if addNUL { + data[pktLen-1] = 0x00 + } + + return c.WritePacket(data) +} + func (c *Conn) ResetSequence() { c.Sequence = 0 } diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index fd045c225..552798bda 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -203,7 +203,7 @@ func (b *BinlogSyncer) registerSlave() error { log.Infof("register slave for master server %s", addr) var err error b.c, err = client.Connect(addr, b.cfg.User, b.cfg.Password, "", func(c *client.Conn) { - c.TLSConfig = b.cfg.TLSConfig + c.SetTLSConfig(b.cfg.TLSConfig) }) if err != nil { return errors.Trace(err) diff --git a/replication/row_event.go b/replication/row_event.go index 1315c51de..432170d4d 100644 --- a/replication/row_event.go +++ b/replication/row_event.go @@ -71,7 +71,7 @@ func (e *TableMapEvent) Decode(data []byte) error { var err error var metaData []byte - if metaData, _, n, err = LengthEnodedString(data[pos:]); err != nil { + if metaData, _, n, err = LengthEncodedString(data[pos:]); err != nil { return errors.Trace(err) } diff --git a/server/auth.go b/server/auth.go index b66ea4e0c..0eb54a63d 100644 --- a/server/auth.go +++ b/server/auth.go @@ -2,118 +2,173 @@ package server import ( "bytes" - "encoding/binary" - + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "fmt" + + "github.com/juju/errors" . "github.com/siddontang/go-mysql/mysql" ) -func (c *Conn) writeInitialHandshake() error { - capability := CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | - CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | - CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION - - data := make([]byte, 4, 128) - - //min version 10 - data = append(data, 10) - - //server version[00] - data = append(data, ServerVersion...) - data = append(data, 0) - - //connection id - data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24)) - - //auth-plugin-data-part-1 - data = append(data, c.salt[0:8]...) - - //filter [00] - data = append(data, 0) - - //capability flag lower 2 bytes, using default capability here - data = append(data, byte(capability), byte(capability>>8)) +var ErrAccessDenied = errors.New("access denied") - //charset, utf-8 default - data = append(data, uint8(DEFAULT_COLLATION_ID)) - - //status - data = append(data, byte(c.status), byte(c.status>>8)) - - //below 13 byte may not be used - //capability flag upper 2 bytes, using default capability here - data = append(data, byte(capability>>16), byte(capability>>24)) - - //filter [0x15], for wireshark dump, value is 0x15 - data = append(data, 0x15) - - //reserved 10 [00] - data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) +func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { + switch authPluginName { + case AUTH_NATIVE_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + return c.compareNativePasswordAuthData(clientAuthData, c.password) - //auth-plugin-data-part-2 - data = append(data, c.salt[8:]...) + case AUTH_CACHING_SHA2_PASSWORD: + if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { + return err + } + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil - //filter [00] - data = append(data, 0) + case AUTH_SHA256_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + cont, err := c.handlePublicKeyRetrieval(clientAuthData) + if err != nil { + return err + } + if !cont { + return nil + } + return c.compareSha256PasswordAuthData(clientAuthData, c.password) - return c.WritePacket(data) + default: + return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) + } } -func (c *Conn) readHandshakeResponse(password string) error { - data, err := c.ReadPacket() - +func (c *Conn) acquirePassword() error { + password, found, err := c.credentialProvider.GetCredential(c.user) if err != nil { return err } - - pos := 0 - - //capability - c.capability = binary.LittleEndian.Uint32(data[:4]) - pos += 4 - - //skip max packet size - pos += 4 - - //charset, skip, if you want to use another charset, use set names - //c.collation = CollationId(data[pos]) - pos++ - - //skip reserved 23[00] - pos += 23 - - //user name - user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) - pos += len(user) + 1 - - if c.user != user { - return NewDefaultError(ER_NO_SUCH_USER, user, c.RemoteAddr().String()) + if !found { + return NewDefaultError(ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } + c.password = password + return nil +} - //auth length and auth - authLen := int(data[pos]) - pos++ - auth := data[pos : pos+authLen] - - checkAuth := CalcPassword(c.salt, []byte(password)) - - if !bytes.Equal(auth, checkAuth) { - return NewDefaultError(ER_ACCESS_DENIED_ERROR, c.RemoteAddr().String(), c.user, "Yes") +func scrambleValidation(cached, nonce, scramble []byte) bool { + // SHA256(SHA256(SHA256(STORED_PASSWORD)), NONCE) + crypt := sha256.New() + crypt.Write(cached) + crypt.Write(nonce) + message2 := crypt.Sum(nil) + // SHA256(PASSWORD) + if len(message2) != len(scramble) { + return false } + for i := range message2 { + message2[i] ^= scramble[i] + } + // SHA256(SHA256(PASSWORD) + crypt.Reset() + crypt.Write(message2) + m := crypt.Sum(nil) + return bytes.Equal(m, cached) +} - pos += authLen +func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { + if bytes.Equal(CalcPassword(c.salt, []byte(c.password)), clientAuthData) { + return nil + } + return ErrAccessDenied +} - if c.capability|CLIENT_CONNECT_WITH_DB > 0 { - if len(data[pos:]) == 0 { +func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error { + // Empty passwords are not hashed, but sent as empty string + if len(clientAuthData) == 0 { + if password == "" { return nil } - - db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) - pos += len(db) + 1 - - if err = c.h.UseDB(db); err != nil { + return ErrAccessDenied + } + if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { + if !tlsConn.ConnectionState().HandshakeComplete { + return errors.New("incomplete TSL handshake") + } + // connection is SSL/TLS, client should send plain password + // deal with the trailing \NUL added for plain text password received + if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { + clientAuthData = clientAuthData[:l-1] + } + if bytes.Equal(clientAuthData, []byte(password)) { + return nil + } + return ErrAccessDenied + } else { + // client should send encrypted password + // decrypt + dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), clientAuthData, nil) + if err != nil { return err } + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(c.salt) + plain[i] ^= c.salt[j] + } + if bytes.Equal(plain, dbytes) { + return nil + } + return ErrAccessDenied } +} +func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { + // Empty passwords are not hashed, but sent as empty string + if len(clientAuthData) == 0 { + if err := c.acquirePassword(); err != nil { + return err + } + if c.password == "" { + return nil + } + return ErrAccessDenied + } + // the caching of 'caching_sha2_password' in MySQL, see: https://dev.mysql.com/worklog/task/?id=9591 + if _, ok := c.credentialProvider.(*InMemoryProvider); ok { + // since we have already kept the password in memory and calculate the scramble is not that high of cost, we eliminate + // the caching part. So our server will never ask the client to do a full authentication via RSA key exchange and it appears + // like the auth will always hit the cache. + if err := c.acquirePassword(); err != nil { + return err + } + if bytes.Equal(CalcCachingSha2Password(c.salt, c.password), clientAuthData) { + // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 + return c.writeAuthMoreDataFastAuth() + } + return ErrAccessDenied + } + // other type of credential provider, we use the cache + cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr())) + if ok { + // Scramble validation + if scrambleValidation(cached.([]byte), c.salt, clientAuthData) { + // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 + return c.writeAuthMoreDataFastAuth() + } + return ErrAccessDenied + } + // cache miss, do full auth + if err := c.writeAuthMoreDataFullAuth(); err != nil { + return err + } + c.cachingSha2FullAuth = true return nil } diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go new file mode 100644 index 000000000..038acffe4 --- /dev/null +++ b/server/auth_switch_response.go @@ -0,0 +1,133 @@ +package server + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "fmt" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) handleAuthSwitchResponse() error { + authData, err := c.readAuthSwitchRequestResponse() + if err != nil { + return err + } + + switch c.authPluginName { + case AUTH_NATIVE_PASSWORD: + if err := c.acquirePassword(); err != nil { + return err + } + if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) { + return ErrAccessDenied + } + return nil + + case AUTH_CACHING_SHA2_PASSWORD: + if !c.cachingSha2FullAuth { + // Switched auth method but no MoreData packet send yet + if err := c.compareCacheSha2PasswordAuthData(authData); err != nil { + return err + } else { + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil + } + } + // AuthMoreData packet already sent, do full auth + if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil { + return err + } + c.writeCachingSha2Cache() + return nil + + case AUTH_SHA256_PASSWORD: + cont, err := c.handlePublicKeyRetrieval(authData) + if err != nil { + return err + } + if !cont { + return nil + } + if err := c.acquirePassword(); err != nil { + return err + } + return c.compareSha256PasswordAuthData(authData, c.password) + + default: + return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName) + } +} + +func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { + if err := c.acquirePassword(); err != nil { + return err + } + if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { + if !tlsConn.ConnectionState().HandshakeComplete { + return errors.New("incomplete TSL handshake") + } + // connection is SSL/TLS, client should send plain password + // deal with the trailing \NUL added for plain text password received + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] + } + if bytes.Equal(authData, []byte(c.password)) { + return nil + } + return ErrAccessDenied + } else { + // client either request for the public key or send the encrypted password + if len(authData) == 1 && authData[0] == 0x02 { + // send the public key + if err := c.writeAuthMoreDataPubkey(); err != nil { + return err + } + // read the encrypted password + var err error + if authData, err = c.readAuthSwitchRequestResponse(); err != nil { + return err + } + } + // the encrypted password + // decrypt + dbytes, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, (c.serverConf.tlsConfig.Certificates[0].PrivateKey).(*rsa.PrivateKey), authData, nil) + if err != nil { + return err + } + plain := make([]byte, len(c.password)+1) + copy(plain, c.password) + for i := range plain { + j := i % len(c.salt) + plain[i] ^= c.salt[j] + } + if bytes.Equal(plain, dbytes) { + return nil + } + return ErrAccessDenied + } +} + +func (c *Conn) writeCachingSha2Cache() { + // write cache + if c.password == "" { + return + } + // SHA256(PASSWORD) + crypt := sha256.New() + crypt.Write([]byte(c.password)) + m1 := crypt.Sum(nil) + // SHA256(SHA256(PASSWORD)) + crypt.Reset() + crypt.Write(m1) + m2 := crypt.Sum(nil) + // caching_sha2_password will maintain an in-memory hash of `user`@`host` => SHA256(SHA256(PASSWORD)) + c.serverConf.cacheShaPassword.Store(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr()), m2) +} diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go new file mode 100644 index 000000000..a8139eb80 --- /dev/null +++ b/server/caching_sha2_cache_test.go @@ -0,0 +1,233 @@ +package server + +import ( + "database/sql" + "fmt" + "net" + "strings" + "sync" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/juju/errors" + . "github.com/pingcap/check" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/test_util/test_keys" +) + +var delay = 50 + +// test caching for 'caching_sha2_password' +// NOTE the idea here is to plugin a throttled credential provider so that the first connection (cache miss) will take longer time +// than the second connection (cache hit). Remember to set the password for MySQL user otherwise it won't cache empty password. +func TestCachingSha2Cache(t *testing.T) { + log.SetLevel(log.LevelDebug) + + remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider.AddUser(*testUser, *testPassword) + cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + + // no TLS + Suite(&cacheTestSuite{ + server: cacheServer, + credProvider: remoteProvider, + tlsPara: "false", + }) + + TestingT(t) +} + +func TestCachingSha2CacheTLS(t *testing.T) { + log.SetLevel(log.LevelDebug) + + remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider.AddUser(*testUser, *testPassword) + cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + + // TLS + Suite(&cacheTestSuite{ + server: cacheServer, + credProvider: remoteProvider, + tlsPara: "skip-verify", + }) + + TestingT(t) +} + +type RemoteThrottleProvider struct { + *InMemoryProvider + delay int // in milliseconds +} + +func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { + time.Sleep(time.Millisecond * time.Duration(m.delay)) + return m.InMemoryProvider.GetCredential(username) +} + +type cacheTestSuite struct { + server *Server + credProvider CredentialProvider + tlsPara string + + db *sql.DB + + l net.Listener +} + +func (s *cacheTestSuite) SetUpSuite(c *C) { + var err error + + s.l, err = net.Listen("tcp", *testAddr) + c.Assert(err, IsNil) + + go s.onAccept(c) + + time.Sleep(30 * time.Millisecond) +} + +func (s *cacheTestSuite) TearDownSuite(c *C) { + if s.l != nil { + s.l.Close() + } +} + +func (s *cacheTestSuite) onAccept(c *C) { + for { + conn, err := s.l.Accept() + if err != nil { + return + } + + go s.onConn(conn, c) + } +} + +func (s *cacheTestSuite) onConn(conn net.Conn, c *C) { + //co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s}) + c.Assert(err, IsNil) + for { + err = co.HandleCommand() + if err != nil { + return + } + } +} + +func (s *cacheTestSuite) runSelect(c *C) { + var a int64 + var b string + + err := s.db.QueryRow("SELECT a, b FROM tbl WHERE id=1").Scan(&a, &b) + c.Assert(err, IsNil) + c.Assert(a, Equals, int64(1)) + c.Assert(b, Equals, "hello world") +} + +func (s *cacheTestSuite) TestCache(c *C) { + // first connection + t1 := time.Now() + var err error + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) + c.Assert(err, IsNil) + s.db.SetMaxIdleConns(4) + s.runSelect(c) + t2 := time.Now() + + d1 := int(t2.Sub(t1).Nanoseconds() / 1e6) + //log.Debugf("first connection took %d milliseconds", d1) + + c.Assert(d1, GreaterEqual, delay) + + if s.db != nil { + s.db.Close() + } + + // second connection + t3 := time.Now() + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) + c.Assert(err, IsNil) + s.db.SetMaxIdleConns(4) + s.runSelect(c) + t4 := time.Now() + + d2 := int(t4.Sub(t3).Nanoseconds() / 1e6) + //log.Debugf("second connection took %d milliseconds", d2) + + c.Assert(d2, Less, delay) + if s.db != nil { + s.db.Close() + } + + s.server.cacheShaPassword = &sync.Map{} +} + +type testCacheHandler struct { + s *cacheTestSuite +} + +func (h *testCacheHandler) UseDB(dbName string) error { + return nil +} + +func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{0, 0, 0, r}, nil + } + case "insert": + return &mysql.Result{0, 1, 0, nil}, nil + case "delete": + return &mysql.Result{0, 0, 1, nil}, nil + case "update": + return &mysql.Result{0, 0, 1, nil}, nil + case "replace": + return &mysql.Result{0, 0, 1, nil}, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } + + return nil, nil +} + +func (h *testCacheHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *testCacheHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} +func (h *testCacheHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { + return 0, 0, nil, nil +} + +func (h *testCacheHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *testCacheHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { + return h.handleQuery(query, true) +} + +func (h *testCacheHandler) HandleOtherCommand(cmd byte, data []byte) error { + return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd)) +} diff --git a/server/command.go b/server/command.go index 3bc23ac33..6c8d13ad5 100644 --- a/server/command.go +++ b/server/command.go @@ -11,8 +11,8 @@ import ( type Handler interface { //handle COM_INIT_DB command, you can check whether the dbName is valid, or other. UseDB(dbName string) error - //handle COM_QUERY comamnd, like SELECT, INSERT, UPDATE, etc... - //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the repsonse, otherwise, we will send Result + //handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc... + //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result HandleQuery(query string) (*Result, error) //handle COM_FILED_LIST command HandleFieldList(table string, fieldWildcard string) ([]*Field, error) diff --git a/server/conn.go b/server/conn.go index d6ea846ae..a279b93ea 100644 --- a/server/conn.go +++ b/server/conn.go @@ -15,15 +15,17 @@ import ( type Conn struct { *packet.Conn - capability uint32 - - connectionID uint32 - - status uint16 - - user string - - salt []byte + serverConf *Server + capability uint32 + authPluginName string + connectionID uint32 + status uint16 + salt []byte // should be 8 + 12 for auth-plugin-data-part-1 and auth-plugin-data-part-2 + + credentialProvider CredentialProvider + user string + password string + cachingSha2FullAuth bool h Handler @@ -35,23 +37,45 @@ type Conn struct { var baseConnID uint32 = 10000 +// create connection with default server settings func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) { - c := new(Conn) - - c.h = h - - c.user = user - c.Conn = packet.NewConn(conn) - - c.connectionID = atomic.AddUint32(&baseConnID, 1) + p := NewInMemoryProvider() + p.AddUser(user, password) + salt, _ := RandomBuf(20) + c := &Conn{ + Conn: packet.NewConn(conn), + serverConf: defaultServer, + credentialProvider: p, + h: h, + connectionID: atomic.AddUint32(&baseConnID, 1), + stmts: make(map[uint32]*Stmt), + salt: salt, + } + c.closed.Set(false) - c.stmts = make(map[uint32]*Stmt) + if err := c.handshake(); err != nil { + c.Close() + return nil, err + } - c.salt, _ = RandomBuf(20) + return c, nil +} +// create connection with customized server settings +func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) (*Conn, error) { + salt, _ := RandomBuf(20) + c := &Conn{ + Conn: packet.NewConn(conn), + serverConf: serverConf, + credentialProvider: p, + h: h, + connectionID: atomic.AddUint32(&baseConnID, 1), + stmts: make(map[uint32]*Stmt), + salt: salt, + } c.closed.Set(false) - if err := c.handshake(password); err != nil { + if err := c.handshake(); err != nil { c.Close() return nil, err } @@ -59,14 +83,16 @@ func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, err return c, nil } -func (c *Conn) handshake(password string) error { +func (c *Conn) handshake() error { if err := c.writeInitialHandshake(); err != nil { return err } - if err := c.readHandshakeResponse(password); err != nil { + if err := c.readHandshakeResponse(); err != nil { + if err == ErrAccessDenied { + err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.LocalAddr().String(), "Yes") + } c.writeError(err) - return err } diff --git a/server/credential_provider.go b/server/credential_provider.go new file mode 100644 index 000000000..3d44eb0c1 --- /dev/null +++ b/server/credential_provider.go @@ -0,0 +1,45 @@ +package server + +import "sync" + +// interface for user credential provider +// hint: can be extended for more functionality +// =================================IMPORTANT NOTE=============================== +// if the password in a third-party credential provider could be updated at runtime, we have to invalidate the caching +// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'. +type CredentialProvider interface { + // check if the user exists + CheckUsername(username string) (bool, error) + // get user credential + GetCredential(username string) (password string, found bool, err error) +} + +func NewInMemoryProvider() *InMemoryProvider { + return &InMemoryProvider{ + userPool: sync.Map{}, + } +} + +// implements a in memory credential provider +type InMemoryProvider struct { + userPool sync.Map // username -> password +} + +func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) { + _, ok := m.userPool.Load(username) + return ok, nil +} + +func (m *InMemoryProvider) GetCredential(username string) (password string, found bool, err error) { + v, ok := m.userPool.Load(username) + if !ok { + return "", false, nil + } + return v.(string), true, nil +} + +func (m *InMemoryProvider) AddUser(username, password string) { + m.userPool.Store(username, password) +} + +type Provider InMemoryProvider diff --git a/server/example/server_example.go b/server/example/server_example.go new file mode 100644 index 000000000..1efa1a307 --- /dev/null +++ b/server/example/server_example.go @@ -0,0 +1,51 @@ +package main + +import ( + "net" + + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/server" + "github.com/siddontang/go-mysql/test_util/test_keys" + + "crypto/tls" + "time" +) + +type RemoteThrottleProvider struct { + *server.InMemoryProvider + delay int // in milliseconds +} + +func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { + time.Sleep(time.Millisecond * time.Duration(m.delay)) + return m.InMemoryProvider.GetCredential(username) +} + +func main() { + l, _ := net.Listen("tcp", "127.0.0.1:3306") + // user either the in-memory credential provider or the remote credential provider (you can implement your own) + //inMemProvider := server.NewInMemoryProvider() + //inMemProvider.AddUser("root", "123") + remoteProvider := &RemoteThrottleProvider{server.NewInMemoryProvider(), 10 + 50} + remoteProvider.AddUser("root", "123") + var tlsConf = server.NewServerTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, tls.VerifyClientCertIfGiven) + for { + c, _ := l.Accept() + go func() { + // Create a connection with user root and an empty password. + // You can use your own handler to handle command here. + svr := server.NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) + conn, err := server.NewCustomizedConn(c, svr, remoteProvider, server.EmptyHandler{}) + + if err != nil { + log.Errorf("Connection error: %v", err) + return + } + + for { + conn.HandleCommand() + } + }() + } +} diff --git a/server/handshake_resp.go b/server/handshake_resp.go new file mode 100644 index 000000000..79af6f22e --- /dev/null +++ b/server/handshake_resp.go @@ -0,0 +1,190 @@ +package server + +import ( + "bytes" + "crypto/tls" + "encoding/binary" + + "github.com/juju/errors" + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) readHandshakeResponse() error { + data, pos, err := c.readFirstPart() + if err != nil { + return err + } + if pos, err = c.readUserName(data, pos); err != nil { + return err + } + authData, authLen, pos, err := c.readAuthData(data, pos) + if err != nil { + return err + } + + pos += authLen + + if pos, err = c.readDb(data, pos); err != nil { + return err + } + + pos = c.readPluginName(data, pos) + + cont, err := c.handleAuthMatch(authData, pos) + if err != nil { + return err + } + if !cont { + return nil + } + + // ignore connect attrs for now, the proxy does not support passing attrs to actual MySQL server + + // try to authenticate the client + return c.compareAuthData(c.authPluginName, authData) +} + +func (c *Conn) readFirstPart() ([]byte, int, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, 0, err + } + + pos := 0 + + // check CLIENT_PROTOCOL_41 + if uint32(binary.LittleEndian.Uint16(data[:2]))&CLIENT_PROTOCOL_41 == 0 { + return nil, 0, errors.New("CLIENT_PROTOCOL_41 compatible client is required") + } + + //capability + c.capability = binary.LittleEndian.Uint32(data[:4]) + if c.capability&CLIENT_SECURE_CONNECTION == 0 { + return nil, 0, errors.New("CLIENT_SECURE_CONNECTION compatible client is required") + } + pos += 4 + + //skip max packet size + pos += 4 + + //charset, skip, if you want to use another charset, use set names + //c.collation = CollationId(data[pos]) + pos++ + + //skip reserved 23[00] + pos += 23 + + // is this a SSLRequest packet? + if len(data) == (4 + 4 + 1 + 23) { + if c.serverConf.capability&CLIENT_SSL == 0 { + return nil, 0, errors.Errorf("The host '%s' does not support SSL connections", c.RemoteAddr().String()) + } + // switch to TLS + tlsConn := tls.Server(c.Conn.Conn, c.serverConf.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return nil, 0, err + } + c.Conn.Conn = tlsConn + + // mysql handshake again + return c.readFirstPart() + } + return data, pos, nil +} + +func (c *Conn) readUserName(data []byte, pos int) (int, error) { + //user name + user := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(user) + 1 + c.user = user + return pos, nil +} + +func (c *Conn) readDb(data []byte, pos int) (int, error) { + if c.capability&CLIENT_CONNECT_WITH_DB != 0 { + if len(data[pos:]) == 0 { + return pos, nil + } + + db := string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(db) + 1 + + if err := c.h.UseDB(db); err != nil { + return 0, err + } + } + return pos, nil +} + +func (c *Conn) readPluginName(data []byte, pos int) int { + if c.capability&CLIENT_PLUGIN_AUTH != 0 { + c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) + pos += len(c.authPluginName) + } else { + // The method used is Native Authentication if both CLIENT_PROTOCOL_41 and CLIENT_SECURE_CONNECTION are set, + // but CLIENT_PLUGIN_AUTH is not set, so we fallback to 'mysql_native_password' + c.authPluginName = AUTH_NATIVE_PASSWORD + } + return pos +} + +func (c *Conn) readAuthData(data []byte, pos int) ([]byte, int, int, error) { + // length encoded data + var auth []byte + var authLen int + if c.capability&CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 { + authData, isNULL, readBytes, err := LengthEncodedString(data[pos:]) + if err != nil { + return nil, 0, 0, err + } + if isNULL { + // no auth length and no auth data, just \NUL, considered invalid auth data, and reject connection as MySQL does + return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.LocalAddr().String(), c.user, "Yes") + } + auth = authData + authLen = readBytes + } else { + //auth length and auth + authLen = int(data[pos]) + pos++ + auth = data[pos : pos+authLen] + if authLen == 0 { + // skip the next \NUL in case the password is empty + pos++ + } + } + return auth, authLen, pos, nil +} + +// Public Key Retrieval +// See: https://dev.mysql.com/doc/internals/en/public-key-retrieval.html +func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { + // if the client use 'sha256_password' auth method, and request for a public key + // we send back a keyfile with Protocol::AuthMoreData + if c.authPluginName == AUTH_SHA256_PASSWORD && len(authData) == 1 && authData[0] == 0x01 { + if c.serverConf.capability&CLIENT_SSL == 0 { + return false, errors.New("server does not support SSL: CLIENT_SSL not enabled") + } + if err := c.writeAuthMoreDataPubkey(); err != nil { + return false, err + } + + return false, c.handleAuthSwitchResponse() + } + return true, nil +} + +func (c *Conn) handleAuthMatch(authData []byte, pos int) (bool, error) { + // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet + // to the client to ask the client to switch. + + if c.authPluginName != c.serverConf.defaultAuthMethod { + if err := c.writeAuthSwitchRequest(c.serverConf.defaultAuthMethod); err != nil { + return false, err + } + c.authPluginName = c.serverConf.defaultAuthMethod + // handle AuthSwitchResponse + return false, c.handleAuthSwitchResponse() + } + return true, nil +} diff --git a/server/initial_handshake.go b/server/initial_handshake.go new file mode 100644 index 000000000..312ac2b68 --- /dev/null +++ b/server/initial_handshake.go @@ -0,0 +1,57 @@ +package server + +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +func (c *Conn) writeInitialHandshake() error { + data := make([]byte, 4) + + //min version 10 + data = append(data, 10) + + //server version[00] + data = append(data, c.serverConf.serverVersion...) + data = append(data, 0x00) + + //connection id + data = append(data, byte(c.connectionID), byte(c.connectionID>>8), byte(c.connectionID>>16), byte(c.connectionID>>24)) + + //auth-plugin-data-part-1 + data = append(data, c.salt[0:8]...) + + //filter 0x00 byte, terminating the first part of a scramble + data = append(data, 0x00) + + defaultFlag := c.serverConf.capability + //capability flag lower 2 bytes, using default capability here + data = append(data, byte(defaultFlag), byte(defaultFlag>>8)) + + //charset + data = append(data, c.serverConf.collationId) + + //status + data = append(data, byte(c.status), byte(c.status>>8)) + + //capability flag upper 2 bytes, using default capability here + data = append(data, byte(defaultFlag>>16), byte(defaultFlag>>24)) + + // server supports CLIENT_PLUGIN_AUTH and CLIENT_SECURE_CONNECTION + data = append(data, byte(8+12+1)) + + //reserved 10 [00] + data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + + //auth-plugin-data-part-2 + data = append(data, c.salt[8:]...) + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // add \NUL to terminate the string + data = append(data, 0x00) + + // auth plugin name + data = append(data, c.serverConf.defaultAuthMethod...) + + // EOF if MySQL version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise, so we use \NUL + data = append(data, 0) + + return c.WritePacket(data) +} diff --git a/server/resp.go b/server/resp.go index 1123032dd..db8632394 100644 --- a/server/resp.go +++ b/server/resp.go @@ -62,6 +62,59 @@ func (c *Conn) writeEOF() error { return c.WritePacket(data) } +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +func (c *Conn) writeAuthSwitchRequest(newAuthPluginName string) error { + data := make([]byte, 4) + data = append(data, EOF_HEADER) + data = append(data, []byte(newAuthPluginName)...) + data = append(data, 0x00) + rnd, err := RandomBuf(20) + if err != nil { + return err + } + // new auth data + c.salt = rnd + data = append(data, c.salt...) + // the online doc states it's a string.EOF, however, the actual MySQL server add a \NUL to the end, without it, the + // official MySQL client will fail. + data = append(data, 0x00) + return c.WritePacket(data) +} + +// see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html +func (c *Conn) readAuthSwitchRequestResponse() ([]byte, error) { + data, err := c.ReadPacket() + if err != nil { + return nil, err + } + if len(data) == 1 && data[0] == 0x00 { + // \NUL + return make([]byte, 0), nil + } + return data, nil +} + +func (c *Conn) writeAuthMoreDataPubkey() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, c.serverConf.pubKey...) + return c.WritePacket(data) +} + +func (c *Conn) writeAuthMoreDataFullAuth() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, CACHE_SHA2_FULL_AUTH) + return c.WritePacket(data) +} + +func (c *Conn) writeAuthMoreDataFastAuth() error { + data := make([]byte, 4) + data = append(data, MORE_DATE_HEADER) + data = append(data, CACHE_SHA2_FAST_AUTH) + return c.WritePacket(data) +} + func (c *Conn) writeResultset(r *Resultset) error { columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) diff --git a/server/server_conf.go b/server/server_conf.go new file mode 100644 index 000000000..353595c4b --- /dev/null +++ b/server/server_conf.go @@ -0,0 +1,103 @@ +package server + +import ( + "crypto/tls" + "fmt" + "sync" + + . "github.com/siddontang/go-mysql/mysql" +) + +var defaultServer = NewDefaultServer() + +// Defines a basic MySQL server with configs. +// +// We do not aim at implementing the whole MySQL connection suite to have the best compatibilities for the clients. +// The MySQL server can be configured to switch auth methods covering 'mysql_old_password', 'mysql_native_password', +// 'mysql_clear_password', 'authentication_windows_client', 'sha256_password', 'caching_sha2_password', etc. +// +// However, since some old auth methods are considered broken with security issues. MySQL major versions like 5.7 and 8.0 default to +// 'mysql_native_password' or 'caching_sha2_password', and most MySQL clients should have already supported at least one of the three auth +// methods 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'. Thus here we will only support these three +// auth methods, and use 'mysql_native_password' as default for maximum compatibility with the clients and leave the other two as +// config options. +// +// The MySQL doc states that 'mysql_old_password' will be used if 'CLIENT_PROTOCOL_41' or 'CLIENT_SECURE_CONNECTION' flag is not set. +// We choose to drop the support for insecure 'mysql_old_password' auth method and require client capability 'CLIENT_PROTOCOL_41' and 'CLIENT_SECURE_CONNECTION' +// are set. Besides, if 'CLIENT_PLUGIN_AUTH' is not set, we fallback to 'mysql_native_password' auth method. +type Server struct { + serverVersion string // e.g. "8.0.12" + protocolVersion int // minimal 10 + capability uint32 // server capability flag + collationId uint8 + defaultAuthMethod string // default authentication method, 'mysql_native_password' + pubKey []byte + tlsConfig *tls.Config + cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD)) +} + +// New mysql server with default settings. +// +// NOTES: +// TLS support will be enabled by default with auto-generated CA and server certificates (however, you can still use +// non-TLS connection). By default, it will verify the client certificate if present. You can enable TLS support on +// the client side without providing a client-side certificate. So only when you need the server to verify client +// identity for maximum security, you need to set a signed certificate for the client. +func NewDefaultServer() *Server { + caPem, caKey := generateCA() + certPem, keyPem := generateAndSignRSACerts(caPem, caKey) + tlsConf := NewServerTLSConfig(caPem, certPem, keyPem, tls.VerifyClientCertIfGiven) + return &Server{ + serverVersion: "5.7.0", + protocolVersion: 10, + capability: CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | + CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_SSL | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + collationId: DEFAULT_COLLATION_ID, + defaultAuthMethod: AUTH_NATIVE_PASSWORD, + pubKey: getPublicKeyFromCert(certPem), + tlsConfig: tlsConf, + cacheShaPassword: new(sync.Map), + } +} + +// New mysql server with customized settings. +// +// NOTES: +// You can control the authentication methods and TLS settings here. +// For auth method, you can specify one of the supported methods 'mysql_native_password', 'caching_sha2_password', and 'sha256_password'. +// The specified auth method will be enforced by the server in the connection phase. That means, client will be asked to switch auth method +// if the supplied auth method is different from the server default. +// And for TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide +// a signed or unsigned certificate to provide different level of security. +func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server { + if !isAuthMethodSupported(defaultAuthMethod) { + panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod)) + } + + //if !isAuthMethodAllowedByServer(defaultAuthMethod, allowedAuthMethods) { + // panic(fmt.Sprintf("default auth method is not one of the allowed auth methods")) + //} + var capFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | + CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + if tlsConfig != nil { + capFlag |= CLIENT_SSL + } + return &Server{ + serverVersion: serverVersion, + protocolVersion: 10, + capability: capFlag, + collationId: collationId, + defaultAuthMethod: defaultAuthMethod, + pubKey: pubKey, + tlsConfig: tlsConfig, + cacheShaPassword: new(sync.Map), + } +} + +func isAuthMethodSupported(authMethod string) bool { + return authMethod == AUTH_NATIVE_PASSWORD || authMethod == AUTH_CACHING_SHA2_PASSWORD || authMethod == AUTH_SHA256_PASSWORD +} + +func (s *Server) InvalidateCache(username string, host string) { + s.cacheShaPassword.Delete(fmt.Sprintf("%s@%s", username, host)) +} diff --git a/server/server_test.go b/server/server_test.go index 7d118d522..1f427fd84 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "database/sql" "flag" "fmt" @@ -12,112 +13,88 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/juju/errors" . "github.com/pingcap/check" - mysql "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-log/log" + "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/test_util/test_keys" ) var testAddr = flag.String("addr", "127.0.0.1:4000", "MySQL proxy server address") var testUser = flag.String("user", "root", "MySQL user") -var testPassword = flag.String("pass", "", "MySQL password") +var testPassword = flag.String("pass", "123456", "MySQL password") var testDB = flag.String("db", "test", "MySQL test database") -func Test(t *testing.T) { - TestingT(t) -} - -type serverTestSuite struct { - db *sql.DB - - l net.Listener -} - -var _ = Suite(&serverTestSuite{}) - -type testHandler struct { - s *serverTestSuite -} - -func (h *testHandler) UseDB(dbName string) error { - return nil +var tlsConf = NewServerTLSConfig(test_keys.CaPem, test_keys.CertPem, test_keys.KeyPem, tls.VerifyClientCertIfGiven) + +func prepareServerConf() []*Server { + // add default server without TLS + var servers = []*Server{ + // with default TLS + NewDefaultServer(), + // for key exchange, CLIENT_SSL must be enabled for the server and if the connection is not secured with TLS + // server permits MYSQL_NATIVE_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // server permits SHA256_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // server permits CACHING_SHA2_PASSWORD only + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + + // test auth switch: server permits SHA256_PASSWORD only but sent different method MYSQL_NATIVE_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits CACHING_SHA2_PASSWORD only but sent different method MYSQL_NATIVE_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits CACHING_SHA2_PASSWORD only but sent different method SHA256_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits MYSQL_NATIVE_PASSWORD only but sent different method SHA256_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_SHA256_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits SHA256_PASSWORD only but sent different method CACHING_SHA2_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + // test auth switch: server permits MYSQL_NATIVE_PASSWORD only but sent different method CACHING_SHA2_PASSWORD in handshake response + NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf), + } + return servers } -func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { - ss := strings.Split(query, " ") - switch strings.ToLower(ss[0]) { - case "select": - var r *mysql.Resultset - var err error - //for handle go mysql driver select @@max_allowed_packet - if strings.Contains(strings.ToLower(query), "max_allowed_packet") { - r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ - []interface{}{mysql.MaxPayloadLen}, - }, binary) - } else { - r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ - []interface{}{1, "hello world"}, - }, binary) - } +func Test(t *testing.T) { + log.SetLevel(log.LevelDebug) + + // general tests + inMemProvider := NewInMemoryProvider() + inMemProvider.AddUser(*testUser, *testPassword) + + servers := prepareServerConf() + //no TLS + for _, svr := range servers { + Suite(&serverTestSuite{ + server: svr, + credProvider: inMemProvider, + tlsPara: "false", + }) + } - if err != nil { - return nil, errors.Trace(err) - } else { - return &mysql.Result{0, 0, 0, r}, nil + // TLS if server supports + for _, svr := range servers { + if svr.tlsConfig != nil { + Suite(&serverTestSuite{ + server: svr, + credProvider: inMemProvider, + tlsPara: "skip-verify", + }) } - case "insert": - return &mysql.Result{0, 1, 0, nil}, nil - case "delete": - return &mysql.Result{0, 0, 1, nil}, nil - case "update": - return &mysql.Result{0, 0, 1, nil}, nil - case "replace": - return &mysql.Result{0, 0, 1, nil}, nil - default: - return nil, fmt.Errorf("invalid query %s", query) } - return nil, nil -} - -func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { - return h.handleQuery(query, false) + TestingT(t) } -func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { - return nil, nil -} -func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { - ss := strings.Split(sql, " ") - switch strings.ToLower(ss[0]) { - case "select": - params = 1 - columns = 2 - case "insert": - params = 2 - columns = 0 - case "replace": - params = 2 - columns = 0 - case "update": - params = 1 - columns = 0 - case "delete": - params = 1 - columns = 0 - default: - err = fmt.Errorf("invalid prepare %s", sql) - } - return params, columns, nil, err -} +type serverTestSuite struct { + server *Server + credProvider CredentialProvider -func (h *testHandler) HandleStmtClose(context interface{}) error { - return nil -} + tlsPara string -func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { - return h.handleQuery(query, true) -} + db *sql.DB -func (h *testHandler) HandleOtherCommand(cmd byte, data []byte) error { - return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd)) + l net.Listener } func (s *serverTestSuite) SetUpSuite(c *C) { @@ -128,9 +105,9 @@ func (s *serverTestSuite) SetUpSuite(c *C) { go s.onAccept(c) - time.Sleep(500 * time.Millisecond) + time.Sleep(20 * time.Millisecond) - s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s", *testUser, *testPassword, *testAddr, *testDB)) + s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, *testAddr, *testDB, s.tlsPara)) c.Assert(err, IsNil) s.db.SetMaxIdleConns(4) @@ -158,9 +135,10 @@ func (s *serverTestSuite) onAccept(c *C) { } func (s *serverTestSuite) onConn(conn net.Conn, c *C) { - co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + //co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) + co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testHandler{s}) c.Assert(err, IsNil) - + // set SSL if defined for { err = co.HandleCommand() if err != nil { @@ -232,3 +210,91 @@ func (s *serverTestSuite) TestStmtExec(c *C) { i, _ = r.RowsAffected() c.Assert(i, Equals, int64(1)) } + +type testHandler struct { + s *serverTestSuite +} + +func (h *testHandler) UseDB(dbName string) error { + return nil +} + +func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, error) { + ss := strings.Split(query, " ") + switch strings.ToLower(ss[0]) { + case "select": + var r *mysql.Resultset + var err error + //for handle go mysql driver select @@max_allowed_packet + if strings.Contains(strings.ToLower(query), "max_allowed_packet") { + r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{ + {mysql.MaxPayloadLen}, + }, binary) + } else { + r, err = mysql.BuildSimpleResultset([]string{"a", "b"}, [][]interface{}{ + {1, "hello world"}, + }, binary) + } + + if err != nil { + return nil, errors.Trace(err) + } else { + return &mysql.Result{0, 0, 0, r}, nil + } + case "insert": + return &mysql.Result{0, 1, 0, nil}, nil + case "delete": + return &mysql.Result{0, 0, 1, nil}, nil + case "update": + return &mysql.Result{0, 0, 1, nil}, nil + case "replace": + return &mysql.Result{0, 0, 1, nil}, nil + default: + return nil, fmt.Errorf("invalid query %s", query) + } + + return nil, nil +} + +func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) { + return h.handleQuery(query, false) +} + +func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { + return nil, nil +} +func (h *testHandler) HandleStmtPrepare(sql string) (params int, columns int, ctx interface{}, err error) { + ss := strings.Split(sql, " ") + switch strings.ToLower(ss[0]) { + case "select": + params = 1 + columns = 2 + case "insert": + params = 2 + columns = 0 + case "replace": + params = 2 + columns = 0 + case "update": + params = 1 + columns = 0 + case "delete": + params = 1 + columns = 0 + default: + err = fmt.Errorf("invalid prepare %s", sql) + } + return params, columns, nil, err +} + +func (h *testHandler) HandleStmtClose(context interface{}) error { + return nil +} + +func (h *testHandler) HandleStmtExecute(ctx interface{}, query string, args []interface{}) (*mysql.Result, error) { + return h.handleQuery(query, true) +} + +func (h *testHandler) HandleOtherCommand(cmd byte, data []byte) error { + return mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd)) +} \ No newline at end of file diff --git a/server/ssl.go b/server/ssl.go new file mode 100644 index 000000000..1f8a9edc8 --- /dev/null +++ b/server/ssl.go @@ -0,0 +1,133 @@ +package server + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" +) + +// generate TLS config for server side +// controlling the security level by authType +func NewServerTLSConfig(caPem, certPem, keyPem []byte, authType tls.ClientAuthType) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPem) { + panic("failed to add ca PEM") + } + + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + panic(err) + } + + config := &tls.Config{ + ClientAuth: authType, + Certificates: []tls.Certificate{cert}, + ClientCAs: pool, + } + return config +} + +// extract RSA public key from certificate +func getPublicKeyFromCert(certPem []byte) []byte { + block, _ := pem.Decode(certPem) + crt, err := x509.ParseCertificate(block.Bytes) + if err != nil { + panic(err) + } + pubKey, err := x509.MarshalPKIXPublicKey(crt.PublicKey.(*rsa.PublicKey)) + if err != nil { + panic(err) + } + return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKey}) +} + +// generate and sign RSA certificates with given CA +// see: https://fale.io/blog/2017/06/05/create-a-pki-in-golang/ +func generateAndSignRSACerts(caPem, caKey []byte) ([]byte, []byte) { + // Load CA + catls, err := tls.X509KeyPair(caPem, caKey) + if err != nil { + panic(err) + } + ca, err := x509.ParseCertificate(catls.Certificate[0]) + if err != nil { + panic(err) + } + + // use the CA to sign certificates + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + panic(err) + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ORGANIZATION_NAME"}, + Country: []string{"COUNTRY_CODE"}, + Province: []string{"PROVINCE"}, + Locality: []string{"CITY"}, + StreetAddress: []string{"ADDRESS"}, + PostalCode: []string{"POSTAL_CODE"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + + // sign the certificate + cert_b, err := x509.CreateCertificate(rand.Reader, ca, cert, &priv.PublicKey, catls.PrivateKey) + if err != nil { + panic(err) + } + certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert_b}) + keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return certPem, keyPem +} + +// generate CA in PEM +// see: https://github.com/golang/go/blob/master/src/crypto/tls/generate_cert.go +func generateCA() ([]byte, []byte) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + panic(err) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"ORGANIZATION_NAME"}, + Country: []string{"COUNTRY_CODE"}, + Province: []string{"PROVINCE"}, + Locality: []string{"CITY"}, + StreetAddress: []string{"ADDRESS"}, + PostalCode: []string{"POSTAL_CODE"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + BasicConstraintsValid: true, + } + + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + panic(err) + } + + caPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + caKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return caPem, caKey +} diff --git a/server/stmt.go b/server/stmt.go index 7a325d71e..9bef23ea7 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -144,7 +144,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { } paramTypes = data[pos : pos+(paramNum<<1)] - pos += (paramNum << 1) + pos += paramNum << 1 paramValues = data[pos:] } @@ -211,7 +211,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) if isUnsigned { args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) } else { - args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) + args[i] = int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) } pos += 2 continue @@ -270,7 +270,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) return ErrMalformPacket } - v, isNull, n, err = LengthEnodedString(paramValues[pos:]) + v, isNull, n, err = LengthEncodedString(paramValues[pos:]) pos += n if err != nil { return errors.Trace(err) @@ -290,7 +290,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) return nil } -// stmt send long data command has no repsonse +// stmt send long data command has no response func (c *Conn) handleStmtSendLongData(data []byte) error { if len(data) < 6 { return nil @@ -340,7 +340,7 @@ func (c *Conn) handleStmtReset(data []byte) (*Result, error) { return &Result{}, nil } -// stmt close command has no repsonse +// stmt close command has no response func (c *Conn) handleStmtClose(data []byte) error { if len(data) < 4 { return nil diff --git a/test_util/test_keys/keys.go b/test_util/test_keys/keys.go new file mode 100644 index 000000000..c1049b6f2 --- /dev/null +++ b/test_util/test_keys/keys.go @@ -0,0 +1,85 @@ +package test_keys + +// here we put the testing encryption keys here +// NOTE THIS IS FOR TESTING ONLY, DO NOT USE THEM IN PRODUCTION! + +var PubPem = []byte(`-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsraCori69OXEkA07Ykp2 +Ju+aHz33PqgKj0qbSbPm6ePh2mer2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqq +gM7uv9JoWCONcKo28cPRK7gH7up7nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz +5Bvi9fThcpf9M6qKdNUa2Wd24rM/n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ +ZaEG86NXFpvvnLHRHiih/gXe7nby1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtV +SX4poNq9YrdP4XW9M0N7nocRU0p5aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGT +MwIDAQAB +-----END PUBLIC KEY-----`) + +var CertPem = []byte(`-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDg06wCf7hcuDANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE4MDgxOTA4NDUyNVoXDTI4MDgxNjA4NDUyNVowRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +ALK2gqK4uvTlxJANO2JKdibvmh899z6oCo9Km0mz5unj4dpnq9hljsQuKtcHUcM4 +HXcE06knaJ4TOF7lcsjaqoDO7r/SaFgjjXCqNvHD0Su4B+7qe52BZZTRV1AANP10 +PvebarXSEzaZUCyHHhSF8+Qb4vX04XKX/TOqinTVGtlnduKzP5+qsaFBtpLAw1V0 +At9EQB5BgnTYtdIsmvD4/2WhBvOjVxab75yx0R4oof4F3u528tbEegcWhBtmy2Xd +HI3S+TLljj3kOOdB+pgrVUl+KaDavWK3T+F1vTNDe56HEVNKeWlLy1scul61E0j9 +IkZAu6aRDxtKdl7bKu0BkzMCAwEAATANBgkqhkiG9w0BAQUFAAOCAQEAma3yFqR7 +xkeaZBg4/1I3jSlaNe5+2JB4iybAkMOu77fG5zytLomTbzdhewsuBwpTVMJdga8T +IdPeIFCin1U+5SkbjSMlpKf+krE+5CyrNJ5jAgO9ATIqx66oCTYXfGlNapGRLfSE +sa0iMqCe/dr4GPU+flW2DZFWiyJVDSF1JjReQnfrWY+SD2SpP/lmlgltnY8MJngd +xBLG5nsZCpUXGB713Q8ZyIm2ThVAMiskcxBleIZDDghLuhGvY/9eFJhZpvOkjWa6 +XGEi4E1G/SA+zVKFl41nHKCdqXdmIOnpcLlFBUVloQok5a95Kqc1TYw3f+WbdFff +99dAgk3gWwWZQA== +-----END CERTIFICATE-----`) + +var KeyPem = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAsraCori69OXEkA07Ykp2Ju+aHz33PqgKj0qbSbPm6ePh2mer +2GWOxC4q1wdRwzgddwTTqSdonhM4XuVyyNqqgM7uv9JoWCONcKo28cPRK7gH7up7 +nYFllNFXUAA0/XQ+95tqtdITNplQLIceFIXz5Bvi9fThcpf9M6qKdNUa2Wd24rM/ +n6qxoUG2ksDDVXQC30RAHkGCdNi10iya8Pj/ZaEG86NXFpvvnLHRHiih/gXe7nby +1sR6BxaEG2bLZd0cjdL5MuWOPeQ450H6mCtVSX4poNq9YrdP4XW9M0N7nocRU0p5 +aUvLWxy6XrUTSP0iRkC7ppEPG0p2Xtsq7QGTMwIDAQABAoIBAGh1m8hHWCg7gXh9 +838RbRx3IswuKS27hWiaQEiFWmzOIb7KqDy1qAxtu+ayRY1paHegH6QY/+Kd824s +ibpzbgQacJ04/HrAVTVMmQ8Z2VLHoAN7lcPL1bd14aZGaLLZVtDeTDJ413grhxxv +4ho27gcgcbo4Z+rWgk7H2WRPCAGYqWYAycm3yF5vy9QaO6edU+T588YsEQOos5iy +5pVFSGDGZkcUp1ukL3BJYR+jvygn6WPCobQ/LScUdi+ucitaI9i+UdlLokZARVRG +M/msqcTM73thR8yVRcexU6NUDxRBfZ/f7moSAEbBmGDXuxDcIyH9KGMQ2rMtN1X3 +lK8UNwkCgYEA2STJq/IUQHjdqd3Dqh/Q7Zm8/pMWFqLJSkqpnFtFsXPyUOx9zDOy +KqkIfGeyKwvsj9X9BcZ0FUKj9zoct1/WpPY+h7i7+z0MIujBh4AMjAcDrt4o76yK +UHuVmG2xKTdJoAbqOdToQeX6E82Ioal5pbB2W7AbCQScNBPZ52jxgtcCgYEA0rE7 +2dFiRm0YmuszFYxft2+GP6NgP3R2TQNEooi1uCXG2xgwObie1YCHzpZ5CfSqJIxP +XB7DXpIWi7PxJoeai2F83LnmdFz6F1BPRobwDoSFNdaSKLg4Yf856zpgYNKhL1fE +OoOXj4VBWBZh1XDfZV44fgwlMIf7edOF1XOagwUCgYAw953O+7FbdKYwF0V3iOM5 +oZDAK/UwN5eC/GFRVDfcM5RycVJRCVtlSWcTfuLr2C2Jpiz/72fgH34QU3eEVsV1 +v94MBznFB1hESw7ReqvZq/9FoO3EVrl+OtBaZmosLD6bKtQJJJ0Xtz/01UW5hxla +pveZ55XBK9v51nwuNjk4UwKBgHD8fJUllSchUCWb5cwzeAz98Kdl7LJ6uQo5q2/i +EllLYOWThiEeIYdrIuklholRPIDXAaPsF2c6vn5yo+q+o6EFSZlw0+YpCjDAb5Lp +wAh5BprFk6HkkM/0t9Guf4rMyYWC8odSlE9x7YXYkuSMYDCTI4Zs6vCoq7I8PbQn +B4AlAoGAZ6Ee5m/ph5UVp/3+cR6jCY7aHBUU/M3pbJSkVjBW+ymEBVJ6sUdz8k3P +x8BiPEQggNN7faWBqRWP7KXPnDYHh6shYUgPJwI5HX6NE/ZDnnXjeysHRyf0oCo5 +S6tHXwHNKB5HS1c/KDyyNGjP2oi/MF4o/MGWNWEcK6TJA3RGOYM= +-----END RSA PRIVATE KEY-----`) + +var CaPem = []byte(`-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIJANeS1FOzWXlZMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwODE2MTUxNDE5WhcNMjEwNjA1MTUxNDE5WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAsV6xlhFxMn14Pn7XBRGLt8/HXmhVVu20IKFgIOyX7gAZr0QLsuT1fGf5 +zH9HrlgOMkfdhV847U03KPfUnBsi9lS6/xOxnH/OzTYM0WW0eNMGF7eoxrS64GSb +PVX4pLi5+uwrrZT5HmDgZi49ANmuX6UYmH/eRRvSIoYUTV6t0aYsLyKvlpEAtRAe +4AlKB236j5ggmJ36QUhTFTbeNbeOOgloTEdPK8Y/kgpnhiqzMdPqqIc7IeXUc456 +yX8MJUgniTM2qCNTFdEw+C2Ok0RbU6TI2SuEgVF4jtCcVEKxZ8kYbioONaePQKFR +/EhdXO+/ag1IEdXElH9knLOfB+zCgwIDAQABo4GnMIGkMB0GA1UdDgQWBBQgHiwD +00upIbCOunlK4HRw89DhjjB1BgNVHSMEbjBsgBQgHiwD00upIbCOunlK4HRw89Dh +jqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNV +BAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANeS1FOzWXlZMAwGA1UdEwQF +MAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAFMZFQTFKU5tWIpWh8BbVZeVZcng0Kiq +qwbhVwaTkqtfmbqw8/w+faOWylmLncQEMmgvnUltGMQlQKBwQM2byzPkz9phal3g +uI0JWJYqtcMyIQUB9QbbhrDNC9kdt/ji/x6rrIqzaMRuiBXqH5LQ9h856yXzArqd +cAQGzzYpbUCIv7ciSB93cKkU73fQLZVy5ZBy1+oAa1V9U4cb4G/20/PDmT+G3Gxz +pEjeDKtz8XINoWgA2cSdfAhNZt5vqJaCIZ8qN0z6C7SUKwUBderERUMLUXdhUldC +KTVHyEPvd0aULd5S5vEpKCnHcQmFcLdoN8t9k9pR9ZgwqXbyJHlxWFo= +-----END CERTIFICATE-----`) diff --git a/vendor/github.com/BurntSushi/toml/COPYING b/vendor/github.com/BurntSushi/toml/COPYING new file mode 100644 index 000000000..5a8e33254 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING b/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING new file mode 100644 index 000000000..5a8e33254 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/cmd/toml-test-decoder/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING b/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING new file mode 100644 index 000000000..5a8e33254 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/cmd/toml-test-encoder/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING b/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING new file mode 100644 index 000000000..5a8e33254 --- /dev/null +++ b/vendor/github.com/BurntSushi/toml/cmd/tomlv/COPYING @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS new file mode 100644 index 000000000..fbe4ec442 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -0,0 +1,90 @@ +# This is the official list of Go-MySQL-Driver authors for copyright purposes. + +# If you are submitting a patch, please add your name or the name of the +# organization which holds the copyright to this list in alphabetical order. + +# Names should be added to this file as +# Name +# The email address is not required for organizations. +# Please keep the list sorted. + + +# Individual Persons + +Aaron Hopkins +Achille Roussel +Alexey Palazhchenko +Andrew Reid +Arne Hormann +Asta Xie +Bulat Gaifullin +Carlos Nieto +Chris Moos +Craig Wilson +Daniel Montoya +Daniel Nichter +Daniël van Eeden +Dave Protasowski +DisposaBoy +Egor Smolyakov +Evan Shaw +Frederick Mayle +Gustavo Kristic +Hajime Nakagami +Hanno Braun +Henri Yandell +Hirotaka Yamamoto +ICHINOSE Shogo +INADA Naoki +Jacek Szwec +James Harr +Jeff Hodges +Jeffrey Charles +Jian Zhen +Joshua Prunier +Julien Lefevre +Julien Schmidt +Justin Li +Justin Nuß +Kamil Dziedzic +Kevin Malachowski +Kieron Woodhouse +Lennart Rudolph +Leonardo YongUk Kim +Linh Tran Tuan +Lion Yang +Luca Looz +Lucas Liu +Luke Scott +Maciej Zimnoch +Michael Woolnough +Nicola Peduzzi +Olivier Mengué +oscarzhao +Paul Bonser +Peter Schultz +Rebecca Chin +Reed Allman +Richard Wilkes +Robert Russell +Runrioter Wung +Shuode Li +Soroush Pour +Stan Putrya +Stanley Gunawan +Thomas Wodarek +Xiangyu Hu +Xiaobing Jiang +Xiuming Chen +Zhenye Xie + +# Organizations + +Barracuda Networks, Inc. +Counting Ltd. +Google Inc. +InfoSum Ltd. +Keybase Inc. +Percona LLC +Pivotal Inc. +Stripe Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/go-sql-driver/mysql/appengine.go index 565614eef..be41f2ee6 100644 --- a/vendor/github.com/go-sql-driver/mysql/appengine.go +++ b/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -11,7 +11,7 @@ package mysql import ( - "appengine/cloudsql" + "google.golang.org/appengine/cloudsql" ) func init() { diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go new file mode 100644 index 000000000..2f61ecd4f --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -0,0 +1,420 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc, false) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { + switch plugin { + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, false, nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, false, ErrOldPassword + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) + return authResp, true, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, false, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return []byte(mc.cfg.Passwd), true, nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, false, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. + authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, false, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return nil, true, nil + } + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + return []byte(mc.cfg.Passwd), true, nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, false, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, false, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, false, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data := mc.buf.takeSmallBuffer(4 + 1) + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + data, err := mc.readPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/go-sql-driver/mysql/buffer.go index 2001feacd..eb4748bf4 100644 --- a/vendor/github.com/go-sql-driver/mysql/buffer.go +++ b/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { // smaller than defaultBufSize // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] + if b.length > 0 { + return nil } - return nil + return b.buf[:length] } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. // Only one buffer (total) can be used at a time. func (b *buffer) takeCompleteBuffer() []byte { - if b.length == 0 { - return b.buf + if b.length > 0 { + return nil } - return nil + return b.buf } diff --git a/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/go-sql-driver/mysql/collations.go index 82079cfb9..136c9e4d1 100644 --- a/vendor/github.com/go-sql-driver/mysql/collations.go +++ b/vendor/github.com/go-sql-driver/mysql/collations.go @@ -9,6 +9,7 @@ package mysql const defaultCollation = "utf8_general_ci" +const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index d82c728f3..911be2060 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -9,13 +9,26 @@ package mysql import ( + "context" + "database/sql" "database/sql/driver" + "io" "net" "strconv" "strings" "time" ) +// a copy of context.Context for Go 1.7 and earlier +type mysqlContext interface { + Done() <-chan struct{} + Err() error + + // defined in context.Context, but not used in this driver: + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} +} + type mysqlConn struct { buf buffer netConn net.Conn @@ -29,7 +42,14 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - strict bool + + // for context support (Go 1.8+) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established @@ -62,22 +82,41 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) } @@ -91,26 +130,39 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { + if !mc.closed.TrySet(true) { + return + } + // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err } - mc.netConn = nil + return ErrInvalidConn } - mc.cfg = nil - mc.buf.nc = nil + return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + return nil, mc.markBadConn(err) } stmt := &mysqlStmt{ @@ -144,7 +196,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -257,7 +309,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -271,7 +323,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, err } query = prepared - args = nil } mc.affectedRows = 0 mc.insertId = 0 @@ -283,32 +334,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) - if err != nil { - return err + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) } // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil && resLen > 0 { - if err = mc.readUntilEOF(); err != nil { + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { return err } - err = mc.readUntilEOF() + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } } - return err + return mc.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -322,7 +384,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } query = prepared - args = nil } // Send command err := mc.writeCommandPacketStr(comQuery, query) @@ -335,15 +396,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro rows.mc = mc if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } + // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable @@ -359,7 +427,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc - rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns @@ -375,3 +443,212 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + if ctx.Done() == nil { + return nil + } + + mc.watching = true + select { + default: + case <-ctx.Done(): + return ctx.Err() + } + if mc.watcher == nil { + return nil + } + + mc.watcher <- ctx + + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan mysqlContext, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx mysqlContext + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/go-sql-driver/mysql/const.go index 88cfff3fd..b1e6b85ef 100644 --- a/vendor/github.com/go-sql-driver/mysql/const.go +++ b/vendor/github.com/go-sql-driver/mysql/const.go @@ -9,7 +9,9 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) @@ -18,10 +20,11 @@ const ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html const ( - iOK byte = 0x00 - iLocalInFile byte = 0xfb - iEOF byte = 0xfe - iERR byte = 0xff + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags @@ -87,8 +90,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -107,7 +112,7 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON byte = iota + 0xf5 + fieldTypeJSON fieldType = iota + 0xf5 fieldTypeNewDecimal fieldTypeEnum fieldTypeSet @@ -161,3 +166,9 @@ const ( statusInTransReadonly statusSessionStateChanged ) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/go-sql-driver/mysql/driver.go index 0022d1f1e..ba1297825 100644 --- a/vendor/github.com/go-sql-driver/mysql/driver.go +++ b/vendor/github.com/go-sql-driver/mysql/driver.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// Package mysql provides a MySQL driver for Go's database/sql package +// Package mysql provides a MySQL driver for Go's database/sql package. // // The driver should be used via the database/sql package: // @@ -20,6 +20,7 @@ import ( "database/sql" "database/sql/driver" "net" + "sync" ) // MySQLDriver is exported to make the driver directly accessible. @@ -30,12 +31,17 @@ type MySQLDriver struct{} // Custom dial functions must be registered with RegisterDial type DialFunc func(addr string) (net.Conn, error) -var dials map[string]DialFunc +var ( + dialsLock sync.RWMutex + dials map[string]DialFunc +) // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. func RegisterDial(net string, dial DialFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() if dials == nil { dials = make(map[string]DialFunc) } @@ -52,16 +58,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), } mc.cfg, err = ParseDSN(dsn) if err != nil { return nil, err } mc.parseTime = mc.cfg.ParseTime - mc.strict = mc.cfg.Strict // Connect to Server - if dial, ok := dials[mc.cfg.Net]; ok { + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} @@ -81,6 +90,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } } + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + mc.buf = newBuffer(mc.netConn) // Set I/O timeouts @@ -88,20 +100,34 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() + authData, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err } + if plugin == "" { + plugin = defaultAuthPlugin + } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, addNUL, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = handleAuthResult(mc, cipher); err != nil { + if err = mc.handleAuthResult(authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. @@ -134,50 +160,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return mc, nil } -func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { - // Read Result Packet - cipher, err := mc.readResultOK() - if err == nil { - return nil // auth successful - } - - if mc.cfg == nil { - return err // auth failed and retry not possible - } - - // Retry auth if configured to do so. - if mc.cfg.AllowOldPasswords && err == ErrOldPassword { - // Retry with old authentication method. Note: there are edge cases - // where this should work but doesn't; this is currently "wontfix": - // https://github.com/go-sql-driver/mysql/issues/184 - - // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is - // sent and we have to keep using the cipher sent in the init packet. - if cipher == nil { - cipher = oldCipher - } - - if err = mc.writeOldAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { - // Retry with clear text password for - // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html - // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } - return err -} - func init() { sql.Register("mysql", &MySQLDriver{}) } diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index ac00dcedd..be014babe 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -10,11 +10,13 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" "errors" "fmt" "net" "net/url" + "sort" "strconv" "strings" "time" @@ -27,7 +29,9 @@ var ( errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. type Config struct { User string // Username Passwd string // Password (requires User) @@ -38,6 +42,8 @@ type Config struct { Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration Timeout time.Duration // Dial timeout @@ -53,7 +59,54 @@ type Config struct { InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time - Strict bool // Return warnings as errors + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + } +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + + return nil } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { } } - if cfg.AllowNativePasswords { + if !cfg.AllowNativePasswords { if hasParam { - buf.WriteString("&allowNativePasswords=true") + buf.WriteString("&allowNativePasswords=false") } else { hasParam = true - buf.WriteString("?allowNativePasswords=true") + buf.WriteString("?allowNativePasswords=false") } } @@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.ReadTimeout.String()) } - if cfg.Strict { + if cfg.RejectReadOnly { if hasParam { - buf.WriteString("&strict=true") + buf.WriteString("&rejectReadOnly=true") } else { hasParam = true - buf.WriteString("?strict=true") + buf.WriteString("?rejectReadOnly=true") } } + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } - if cfg.MaxAllowedPacket > 0 { + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { if hasParam { buf.WriteString("&maxAllowedPacket=") } else { @@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { // other params if cfg.Params != nil { - for param, value := range cfg.Params { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { if hasParam { buf.WriteByte('&') } else { @@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(param) buf.WriteByte('=') - buf.WriteString(url.QueryEscape(value)) + buf.WriteString(url.QueryEscape(cfg.Params[param])) } } @@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { // ParseDSN parses the DSN string to a Config func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = &Config{ - Loc: time.UTC, - Collation: defaultCollation, - } + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { - return nil, errInvalidDSNUnsafeCollation - } - - // Set default network if empty - if cfg.Net == "" { - cfg.Net = "tcp" + if err = cfg.normalize(); err != nil { + return nil, err } - - // Set default address if empty - if cfg.Addr == "" { - switch cfg.Net { - case "tcp": - cfg.Addr = "127.0.0.1:3306" - case "unix": - cfg.Addr = "/tmp/mysql.sock" - default: - return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") - } - - } - return } @@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool @@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } - // Strict mode - case "strict": + // Reject read-only connections + case "rejectReadOnly": var isBool bool - cfg.Strict, isBool = readBool(value) + cfg.RejectReadOnly, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + + if pubKey := getServerPubKey(name); pubKey != nil { + cfg.ServerPubKey = name + cfg.pubKey = pubKey + } else { + return errors.New("invalid value / unknown server pub key name: " + name) + } + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + // Dial Timeout case "timeout": cfg.Timeout, err = time.ParseDuration(value) @@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { return fmt.Errorf("invalid value for TLS config name: %v", err) } - if tlsConfig, ok := tlsConfigRegister[name]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - + if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { cfg.TLSConfig = name cfg.tls = tlsConfig } else { @@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/vendor/github.com/go-sql-driver/mysql/errors.go b/vendor/github.com/go-sql-driver/mysql/errors.go index 857854e14..760782ff2 100644 --- a/vendor/github.com/go-sql-driver/mysql/errors.go +++ b/vendor/github.com/go-sql-driver/mysql/errors.go @@ -9,10 +9,8 @@ package mysql import ( - "database/sql/driver" "errors" "fmt" - "io" "log" "os" ) @@ -31,6 +29,12 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") ) var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) @@ -59,74 +63,3 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } - -// MySQLWarnings is an error type which represents a group of one or more MySQL -// warnings -type MySQLWarnings []MySQLWarning - -func (mws MySQLWarnings) Error() string { - var msg string - for i, warning := range mws { - if i > 0 { - msg += "\r\n" - } - msg += fmt.Sprintf( - "%s %s: %s", - warning.Level, - warning.Code, - warning.Message, - ) - } - return msg -} - -// MySQLWarning is an error type which represents a single MySQL warning. -// Warnings are returned in groups only. See MySQLWarnings -type MySQLWarning struct { - Level string - Code string - Message string -} - -func (mc *mysqlConn) getWarnings() (err error) { - rows, err := mc.Query("SHOW WARNINGS", nil) - if err != nil { - return - } - - var warnings = MySQLWarnings{} - var values = make([]driver.Value, 3) - - for { - err = rows.Next(values) - switch err { - case nil: - warning := MySQLWarning{} - - if raw, ok := values[0].([]byte); ok { - warning.Level = string(raw) - } else { - warning.Level = fmt.Sprintf("%s", values[0]) - } - if raw, ok := values[1].([]byte); ok { - warning.Code = string(raw) - } else { - warning.Code = fmt.Sprintf("%s", values[1]) - } - if raw, ok := values[2].([]byte); ok { - warning.Message = string(raw) - } else { - warning.Message = fmt.Sprintf("%s", values[0]) - } - - warnings = append(warnings, warning) - - case io.EOF: - return warnings - - default: - rows.Close() - return - } - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/go-sql-driver/mysql/fields.go new file mode 100644 index 000000000..e1e2ece4b --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/fields.go @@ -0,0 +1,194 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 547357cfa..273cb0ba5 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } // send content packets - if err == nil { + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { data := make([]byte, 4+packetSize) var n int for err == nil { @@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { - _, err = mc.readResultOK() - return err + return mc.readResultOK() } mc.readPacket() diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index aafe9793e..170aaa02b 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -30,9 +30,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } // packet length [24 bit] @@ -54,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if prevData == nil { errLog.Print(ErrMalformPkt) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } return prevData, nil @@ -63,9 +66,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } // return data if this was the last packet @@ -125,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() errLog.Print(err) } - return driver.ErrBadConn + return ErrInvalidConn } } /****************************************************************************** -* Initialisation Process * +* Initialization Process * ******************************************************************************/ // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { - return nil, err + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( + return nil, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -163,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher := data[pos : pos+8] + authData := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -171,10 +191,10 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + return nil, "", ErrNoTLS } pos += 2 @@ -198,32 +218,32 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. - cipher = append(cipher, data[pos:pos+12]...) + authData = append(authData, data[pos:pos+12]...) + pos += 13 - // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // - //if data[len(data)-1] == 0 { - // return - //} - //return ErrMalformPkt + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } // make a memory safe copy of the cipher slice var b [20]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -247,10 +267,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientMultiStatements } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + if addNUL { + pktLen++ + } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -261,9 +290,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Calculate packet length and get buffer with that size data := mc.buf.takeSmallBuffer(pktLen + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // ClientFlags [32 bit] @@ -318,9 +347,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 pos++ - // ScrambleBuffer [length encoded integer] - data[pos] = byte(len(scrambleBuff)) - pos += 1 + copy(data[pos+1:], scrambleBuff) + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) + if addNUL { + data[pos] = 0x00 + pos++ + } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -329,72 +362,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], plugin) data[pos] = 0x00 // Send Auth packet return mc.writePacket(data) } -// Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { - // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { + pktLen := 4 + len(authData) + if addNUL { + pktLen++ } - - // Add the scrambled password [null terminated string] - copy(data[4:], scrambleBuff) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) + data := mc.buf.takeSmallBuffer(pktLen) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.Passwd) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Native password authentication method -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + // Add the auth data [EOF] + copy(data[4:], authData) + if addNUL { + data[pktLen-1] = 0x00 } - // Add the scramble - copy(data[4:], scrambleBuff) - return mc.writePacket(data) } @@ -408,9 +401,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -427,9 +420,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { pktLen := 1 + len(arg) data := mc.buf.takeBuffer(pktLen + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -448,9 +441,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -470,44 +463,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ -// Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() ([]byte, error) { +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { data, err := mc.readPacket() - if err == nil { - // packet indicator - switch data[0] { + if err != nil { + return nil, "", err + } - case iOK: - return nil, mc.handleOkPacket(data) + // packet indicator + switch data[0] { - case iEOF: - if len(data) > 1 { - pluginEndIndex := bytes.IndexByte(data, 0x00) - plugin := string(data[1:pluginEndIndex]) - cipher := data[pluginEndIndex+1 : len(data)-1] - - if plugin == "mysql_old_password" { - // using old_passwords - return cipher, ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return cipher, ErrCleartextPassword - } else if plugin == "mysql_native_password" { - // using mysql default authentication method - return cipher, ErrNativePassword - } else { - return cipher, ErrUnknownPlugin - } - } else { - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, ErrOldPassword - } + case iOK: + return nil, "", mc.handleOkPacket(data) - default: // Error otherwise - return nil, mc.handleErrorPacket(data) + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) < 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { + data, err := mc.readPacket() + if err != nil { + return err + } + + if data[0] == iOK { + return mc.handleOkPacket(data) } - return nil, err + return mc.handleErrorPacket(data) } // Result Set Header Packet @@ -550,6 +549,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // Error Number [16 bit uint] errno := binary.LittleEndian.Uint16(data[1:3]) + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. + mc.Close() + return driver.ErrBadConn + } + pos := 3 // SQL State [optional: # + 5bytes string] @@ -584,19 +599,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardResults(); err != nil { - return err + if mc.status&statusMoreResultsExists != 0 { + return nil } // warning count [2 bytes] - if !mc.strict { - return nil - } - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } return nil } @@ -668,14 +676,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -698,6 +713,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -707,15 +726,11 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - err = rows.mc.discardResults() - if err == nil { - err = io.EOF - } else { - // connection unusable - rows.mc.Close() + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil - return err + return io.EOF } if data[0] == iERR { rows.mc = nil @@ -736,7 +751,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -808,14 +823,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Reserved [8 bit] // Warning count [16 bit uint] - if !stmt.mc.strict { - return columnCount, nil - } - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } return columnCount, nil } return 0, err @@ -832,7 +840,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // 2 bytes paramID const dataOffset = 1 + 4 + 2 - // Can not use the write buffer since + // Cannot use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) @@ -887,6 +895,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 @@ -898,9 +912,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = mc.buf.takeCompleteBuffer() } if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // command [1 byte] @@ -959,7 +973,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -967,7 +981,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -983,7 +997,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -999,7 +1013,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -1011,10 +1025,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1029,14 +1043,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1048,23 +1062,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - var val []byte + var a [64]byte + var b = a[:0] + if v.IsZero() { - val = []byte("0000-00-00") + b = append(b, "0000-00-00"...) } else { - val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) } paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(val)), + uint64(len(b)), ) - paramValues = append(paramValues, val...) + paramValues = append(paramValues, b...) default: - return fmt.Errorf("can not convert type: %T", arg) + return fmt.Errorf("cannot convert type: %T", arg) } } @@ -1097,8 +1113,6 @@ func (mc *mysqlConn) discardResults() error { if err := mc.readUntilEOF(); err != nil { return err } - } else { - mc.status &^= statusMoreResultsExists } } return nil @@ -1116,20 +1130,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - err = rows.mc.discardResults() - if err == nil { - err = io.EOF - } else { - // connection unusable - rows.mc.Close() + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil - return err + return io.EOF } + mc := rows.mc rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] @@ -1145,14 +1156,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1161,7 +1172,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1170,7 +1181,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1179,7 +1190,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1193,7 +1204,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue @@ -1233,10 +1244,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1244,18 +1255,18 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) case rows.mc.parseTime: dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1263,11 +1274,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) } if err == nil { @@ -1279,7 +1290,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index c08255eee..d3b1e2822 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -11,19 +11,20 @@ package mysql import ( "database/sql/driver" "io" + "math" + "reflect" ) -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte +type resultSet struct { + columns []mysqlField + columnNames []string + done bool } type mysqlRows struct { - mc *mysqlConn - columns []mysqlField + mc *mysqlConn + rs resultSet + finish func() } type binaryRows struct { @@ -34,37 +35,86 @@ type textRows struct { mysqlRows } -type emptyRows struct{} - func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } + + rows.rs.columnNames = columns return columns } -func (rows *mysqlRows) Close() error { +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + mc := rows.mc if mc == nil { return nil } - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Remove unread packets from stream - err := mc.readUntilEOF() + if !rows.rs.done { + err = mc.readUntilEOF() + } if err == nil { if err = mc.discardResults(); err != nil { return err @@ -75,22 +125,66 @@ func (rows *mysqlRows) Close() error { return err } -func (rows *binaryRows) Next(dest []driver.Value) error { - if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err } + rows.rs.done = true + } - // Fetch next row from stream - return rows.readRow(dest) + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF } - return io.EOF + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() } -func (rows *textRows) Next(dest []driver.Value) error { +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -99,14 +193,24 @@ func (rows *textRows) Next(dest []driver.Value) error { return io.EOF } -func (rows emptyRows) Columns() []string { - return nil -} +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } -func (rows emptyRows) Close() error { - return nil + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err } -func (rows emptyRows) Next(dest []driver.Value) error { +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } return io.EOF } diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index 7f9b04585..ce7fe4cd0 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -11,6 +11,7 @@ package mysql import ( "database/sql/driver" "fmt" + "io" "reflect" "strconv" ) @@ -19,11 +20,10 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int - columns []mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { + if stmt.mc == nil || stmt.mc.closed.IsSet() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. @@ -45,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -62,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } - - // Rows - err = mc.readUntilEOF() + if err != nil { + return nil, err + } + + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err } } - return nil, err + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -107,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { if resLen > 0 { rows.mc = mc - // Columns - // If not cached, read them and cache them - if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns - } else { - rows.columns = stmt.columns - err = mc.readUntilEOF() + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err } } @@ -123,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { type converter struct{} +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(v) { return v, nil } + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) } - return c.ConvertValue(rv.Elem().Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: @@ -148,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/vendor/github.com/go-sql-driver/mysql/transaction.go b/vendor/github.com/go-sql-driver/mysql/transaction.go index 33c749b35..417d72793 100644 --- a/vendor/github.com/go-sql-driver/mysql/transaction.go +++ b/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index d523b7ffd..cb3650bb9 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -9,23 +9,32 @@ package mysql import ( - "crypto/sha1" "crypto/tls" + "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" + "strconv" "strings" + "sync" + "sync/atomic" "time" ) +// Registry for custom tls.Configs var ( - tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. +// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { @@ -51,19 +60,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { return fmt.Errorf("key '%s' is reserved", key) } - if tlsConfigRegister == nil { - tlsConfigRegister = make(map[string]*tls.Config) + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) } - tlsConfigRegister[key] = config + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { - if tlsConfigRegister != nil { - delete(tlsConfigRegister, key) + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) + } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = v.Clone() } + tlsConfigLock.RUnlock() + return } // Returns the bool value of the input. @@ -80,119 +102,6 @@ func readBool(input string) (value bool, valid bool) { return } -/****************************************************************************** -* Authentication * -******************************************************************************/ - -// Encrypt password using 4.1+ method -func scramblePassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write(password) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - -// Encrypt password using pre 4.1 (old password) method -// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c -type myRnd struct { - seed1, seed2 uint32 -} - -const myRndMaxVal = 0x3FFFFFFF - -// Pseudo random number generator -func newMyRnd(seed1, seed2 uint32) *myRnd { - return &myRnd{ - seed1: seed1 % myRndMaxVal, - seed2: seed2 % myRndMaxVal, - } -} - -// Tested to be equivalent to MariaDB's floating point variant -// http://play.golang.org/p/QHvhd4qved -// http://play.golang.org/p/RG0q4ElWDx -func (r *myRnd) NextByte() byte { - r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal - r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal - - return byte(uint64(r.seed1) * 31 / myRndMaxVal) -} - -// Generate binary hash from byte string using insecure pre 4.1 method -func pwHash(password []byte) (result [2]uint32) { - var add uint32 = 7 - var tmp uint32 - - result[0] = 1345345333 - result[1] = 0x12345671 - - for _, c := range password { - // skip spaces and tabs in password - if c == ' ' || c == '\t' { - continue - } - - tmp = uint32(c) - result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) - result[1] += (result[1] << 8) ^ result[0] - add += tmp - } - - // Remove sign bit (1<<31)-1) - result[0] &= 0x7FFFFFFF - result[1] &= 0x7FFFFFFF - - return -} - -// Encrypt password using insecure pre 4.1 method -func scrambleOldPassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - scramble = scramble[:8] - - hashPw := pwHash(password) - hashSc := pwHash(scramble) - - r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) - - var out [8]byte - for i := range out { - out[i] = r.NextByte() + 64 - } - - mask := r.NextByte() - for i := range out { - out[i] ^= mask - } - - return out[:] -} - /****************************************************************************** * Time related utils * ******************************************************************************/ @@ -321,139 +230,154 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" -func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { - // length expects the deterministic length of the zero value, - // negative time and 100+ hours are automatically added if needed - if len(src) == 0 { - if justTime { - return zeroDateTime[11 : 11+length], nil - } - return zeroDateTime[:length], nil - } - var dst []byte // return value - var pt, p1, p2, p3 byte // current digit pair - var zOffs byte // offset of value in zeroDateTime - if justTime { - switch length { - case - 8, // time (can be up to 10 when negative and 100+ hours) - 10, 11, 12, 13, 14, 15: // time with fractional seconds - default: - return nil, fmt.Errorf("illegal TIME length %d", length) - } - switch len(src) { - case 8, 12: - default: - return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) - } - // +2 to enable negative time and 100+ hours - dst = make([]byte, 0, length+2) - if src[0] == 1 { - dst = append(dst, '-') - } - if src[1] != 0 { - hour := uint16(src[1])*24 + uint16(src[5]) - pt = byte(hour / 100) - p1 = byte(hour - 100*uint16(pt)) - dst = append(dst, digits01[pt]) - } else { - p1 = src[5] - } - zOffs = 11 - src = src[6:] - } else { - switch length { - case 10, 19, 21, 22, 23, 24, 25, 26: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s length %d", t, length) - } - switch len(src) { - case 4, 7, 11: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) - } - dst = make([]byte, 0, length) - // start with the date - year := binary.LittleEndian.Uint16(src[:2]) - pt = byte(year / 100) - p1 = byte(year - 100*uint16(pt)) - p2, p3 = src[2], src[3] - dst = append(dst, - digits10[pt], digits01[pt], - digits10[p1], digits01[p1], '-', - digits10[p2], digits01[p2], '-', - digits10[p3], digits01[p3], - ) - if length == 10 { - return dst, nil - } - if len(src) == 4 { - return append(dst, zeroDateTime[10:length]...), nil - } - dst = append(dst, ' ') - p1 = src[4] // hour - src = src[5:] +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst } - // p1 is 2-digit hour, src is after hour - p2, p3 = src[0], src[1] - dst = append(dst, - digits10[p1], digits01[p1], ':', - digits10[p2], digits01[p2], ':', - digits10[p3], digits01[p3], - ) - if length <= byte(len(dst)) { - return dst, nil - } - src = src[2:] if len(src) == 0 { - return append(dst, zeroDateTime[19:zOffs+length]...), nil + return append(dst, ".000000"[:decimals+1]...) } + microsecs := binary.LittleEndian.Uint32(src[:4]) - p1 = byte(microsecs / 10000) + p1 := byte(microsecs / 10000) microsecs -= 10000 * uint32(p1) - p2 = byte(microsecs / 100) + p2 := byte(microsecs / 100) microsecs -= 100 * uint32(p2) - p3 = byte(microsecs) - switch decimals := zOffs + length - 20; decimals { + p3 := byte(microsecs) + + switch decimals { default: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], digits01[p3], - ), nil + ) case 1: return append(dst, '.', digits10[p1], - ), nil + ) case 2: return append(dst, '.', digits10[p1], digits01[p1], - ), nil + ) case 3: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], - ), nil + ) case 4: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], - ), nil + ) case 5: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], - ), nil + ) + } +} + +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[:length], nil + } + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s length %d", t, length) + } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + + // p1 is 2-digit hour, src is after hour + p2, p3 = src[0], src[1] + dst = append(dst, + digits10[p1], digits01[p1], ':', + digits10[p2], digits01[p2], ':', + digits10[p3], digits01[p3], + ) + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[11 : 11+length], nil + } + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds + default: + return nil, fmt.Errorf("illegal TIME length %d", length) + } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil } /****************************************************************************** @@ -519,7 +443,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } @@ -548,8 +472,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 @@ -738,3 +662,94 @@ func escapeStringQuotes(buf []byte, v string) []byte { return buf[:pos] } + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns wether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns wether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/vendor/github.com/shopspring/decimal/decimal-go.go b/vendor/github.com/shopspring/decimal/decimal-go.go new file mode 100644 index 000000000..e08a15ce4 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/decimal-go.go @@ -0,0 +1,414 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. +package decimal + +type decimal struct { + d [800]byte // digits, big-endian representation + nd int // number of digits used + dp int // decimal point + neg bool // negative flag + trunc bool // discarded nonzero digits beyond d[:nd] +} + +func (a *decimal) String() string { + n := 10 + a.nd + if a.dp > 0 { + n += a.dp + } + if a.dp < 0 { + n += -a.dp + } + + buf := make([]byte, n) + w := 0 + switch { + case a.nd == 0: + return "0" + + case a.dp <= 0: + // zeros fill space between decimal point and digits + buf[w] = '0' + w++ + buf[w] = '.' + w++ + w += digitZero(buf[w : w+-a.dp]) + w += copy(buf[w:], a.d[0:a.nd]) + + case a.dp < a.nd: + // decimal point in middle of digits + w += copy(buf[w:], a.d[0:a.dp]) + buf[w] = '.' + w++ + w += copy(buf[w:], a.d[a.dp:a.nd]) + + default: + // zeros fill space between digits and decimal point + w += copy(buf[w:], a.d[0:a.nd]) + w += digitZero(buf[w : w+a.dp-a.nd]) + } + return string(buf[0:w]) +} + +func digitZero(dst []byte) int { + for i := range dst { + dst[i] = '0' + } + return len(dst) +} + +// trim trailing zeros from number. +// (They are meaningless; the decimal point is tracked +// independent of the number of digits.) +func trim(a *decimal) { + for a.nd > 0 && a.d[a.nd-1] == '0' { + a.nd-- + } + if a.nd == 0 { + a.dp = 0 + } +} + +// Assign v to a. +func (a *decimal) Assign(v uint64) { + var buf [24]byte + + // Write reversed decimal in buf. + n := 0 + for v > 0 { + v1 := v / 10 + v -= 10 * v1 + buf[n] = byte(v + '0') + n++ + v = v1 + } + + // Reverse again to produce forward decimal in a.d. + a.nd = 0 + for n--; n >= 0; n-- { + a.d[a.nd] = buf[n] + a.nd++ + } + a.dp = a.nd + trim(a) +} + +// Maximum shift that we can do in one pass without overflow. +// A uint has 32 or 64 bits, and we have to be able to accommodate 9<> 63) +const maxShift = uintSize - 4 + +// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow. +func rightShift(a *decimal, k uint) { + r := 0 // read pointer + w := 0 // write pointer + + // Pick up enough leading digits to cover first shift. + var n uint + for ; n>>k == 0; r++ { + if r >= a.nd { + if n == 0 { + // a == 0; shouldn't get here, but handle anyway. + a.nd = 0 + return + } + for n>>k == 0 { + n = n * 10 + r++ + } + break + } + c := uint(a.d[r]) + n = n*10 + c - '0' + } + a.dp -= r - 1 + + var mask uint = (1 << k) - 1 + + // Pick up a digit, put down a digit. + for ; r < a.nd; r++ { + c := uint(a.d[r]) + dig := n >> k + n &= mask + a.d[w] = byte(dig + '0') + w++ + n = n*10 + c - '0' + } + + // Put down extra digits. + for n > 0 { + dig := n >> k + n &= mask + if w < len(a.d) { + a.d[w] = byte(dig + '0') + w++ + } else if dig > 0 { + a.trunc = true + } + n = n * 10 + } + + a.nd = w + trim(a) +} + +// Cheat sheet for left shift: table indexed by shift count giving +// number of new digits that will be introduced by that shift. +// +// For example, leftcheats[4] = {2, "625"}. That means that +// if we are shifting by 4 (multiplying by 16), it will add 2 digits +// when the string prefix is "625" through "999", and one fewer digit +// if the string prefix is "000" through "624". +// +// Credit for this trick goes to Ken. + +type leftCheat struct { + delta int // number of new digits + cutoff string // minus one digit if original < a. +} + +var leftcheats = []leftCheat{ + // Leading digits of 1/2^i = 5^i. + // 5^23 is not an exact 64-bit floating point number, + // so have to use bc for the math. + // Go up to 60 to be large enough for 32bit and 64bit platforms. + /* + seq 60 | sed 's/^/5^/' | bc | + awk 'BEGIN{ print "\t{ 0, \"\" }," } + { + log2 = log(2)/log(10) + printf("\t{ %d, \"%s\" },\t// * %d\n", + int(log2*NR+1), $0, 2**NR) + }' + */ + {0, ""}, + {1, "5"}, // * 2 + {1, "25"}, // * 4 + {1, "125"}, // * 8 + {2, "625"}, // * 16 + {2, "3125"}, // * 32 + {2, "15625"}, // * 64 + {3, "78125"}, // * 128 + {3, "390625"}, // * 256 + {3, "1953125"}, // * 512 + {4, "9765625"}, // * 1024 + {4, "48828125"}, // * 2048 + {4, "244140625"}, // * 4096 + {4, "1220703125"}, // * 8192 + {5, "6103515625"}, // * 16384 + {5, "30517578125"}, // * 32768 + {5, "152587890625"}, // * 65536 + {6, "762939453125"}, // * 131072 + {6, "3814697265625"}, // * 262144 + {6, "19073486328125"}, // * 524288 + {7, "95367431640625"}, // * 1048576 + {7, "476837158203125"}, // * 2097152 + {7, "2384185791015625"}, // * 4194304 + {7, "11920928955078125"}, // * 8388608 + {8, "59604644775390625"}, // * 16777216 + {8, "298023223876953125"}, // * 33554432 + {8, "1490116119384765625"}, // * 67108864 + {9, "7450580596923828125"}, // * 134217728 + {9, "37252902984619140625"}, // * 268435456 + {9, "186264514923095703125"}, // * 536870912 + {10, "931322574615478515625"}, // * 1073741824 + {10, "4656612873077392578125"}, // * 2147483648 + {10, "23283064365386962890625"}, // * 4294967296 + {10, "116415321826934814453125"}, // * 8589934592 + {11, "582076609134674072265625"}, // * 17179869184 + {11, "2910383045673370361328125"}, // * 34359738368 + {11, "14551915228366851806640625"}, // * 68719476736 + {12, "72759576141834259033203125"}, // * 137438953472 + {12, "363797880709171295166015625"}, // * 274877906944 + {12, "1818989403545856475830078125"}, // * 549755813888 + {13, "9094947017729282379150390625"}, // * 1099511627776 + {13, "45474735088646411895751953125"}, // * 2199023255552 + {13, "227373675443232059478759765625"}, // * 4398046511104 + {13, "1136868377216160297393798828125"}, // * 8796093022208 + {14, "5684341886080801486968994140625"}, // * 17592186044416 + {14, "28421709430404007434844970703125"}, // * 35184372088832 + {14, "142108547152020037174224853515625"}, // * 70368744177664 + {15, "710542735760100185871124267578125"}, // * 140737488355328 + {15, "3552713678800500929355621337890625"}, // * 281474976710656 + {15, "17763568394002504646778106689453125"}, // * 562949953421312 + {16, "88817841970012523233890533447265625"}, // * 1125899906842624 + {16, "444089209850062616169452667236328125"}, // * 2251799813685248 + {16, "2220446049250313080847263336181640625"}, // * 4503599627370496 + {16, "11102230246251565404236316680908203125"}, // * 9007199254740992 + {17, "55511151231257827021181583404541015625"}, // * 18014398509481984 + {17, "277555756156289135105907917022705078125"}, // * 36028797018963968 + {17, "1387778780781445675529539585113525390625"}, // * 72057594037927936 + {18, "6938893903907228377647697925567626953125"}, // * 144115188075855872 + {18, "34694469519536141888238489627838134765625"}, // * 288230376151711744 + {18, "173472347597680709441192448139190673828125"}, // * 576460752303423488 + {19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976 +} + +// Is the leading prefix of b lexicographically less than s? +func prefixIsLessThan(b []byte, s string) bool { + for i := 0; i < len(s); i++ { + if i >= len(b) { + return true + } + if b[i] != s[i] { + return b[i] < s[i] + } + } + return false +} + +// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow. +func leftShift(a *decimal, k uint) { + delta := leftcheats[k].delta + if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) { + delta-- + } + + r := a.nd // read index + w := a.nd + delta // write index + + // Pick up a digit, put down a digit. + var n uint + for r--; r >= 0; r-- { + n += (uint(a.d[r]) - '0') << k + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + // Put down extra digits. + for n > 0 { + quo := n / 10 + rem := n - 10*quo + w-- + if w < len(a.d) { + a.d[w] = byte(rem + '0') + } else if rem != 0 { + a.trunc = true + } + n = quo + } + + a.nd += delta + if a.nd >= len(a.d) { + a.nd = len(a.d) + } + a.dp += delta + trim(a) +} + +// Binary shift left (k > 0) or right (k < 0). +func (a *decimal) Shift(k int) { + switch { + case a.nd == 0: + // nothing to do: a == 0 + case k > 0: + for k > maxShift { + leftShift(a, maxShift) + k -= maxShift + } + leftShift(a, uint(k)) + case k < 0: + for k < -maxShift { + rightShift(a, maxShift) + k += maxShift + } + rightShift(a, uint(-k)) + } +} + +// If we chop a at nd digits, should we round up? +func shouldRoundUp(a *decimal, nd int) bool { + if nd < 0 || nd >= a.nd { + return false + } + if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even + // if we truncated, a little higher than what's recorded - always round up + if a.trunc { + return true + } + return nd > 0 && (a.d[nd-1]-'0')%2 != 0 + } + // not halfway - digit tells all + return a.d[nd] >= '5' +} + +// Round a to nd digits (or fewer). +// If nd is zero, it means we're rounding +// just to the left of the digits, as in +// 0.09 -> 0.1. +func (a *decimal) Round(nd int) { + if nd < 0 || nd >= a.nd { + return + } + if shouldRoundUp(a, nd) { + a.RoundUp(nd) + } else { + a.RoundDown(nd) + } +} + +// Round a down to nd digits (or fewer). +func (a *decimal) RoundDown(nd int) { + if nd < 0 || nd >= a.nd { + return + } + a.nd = nd + trim(a) +} + +// Round a up to nd digits (or fewer). +func (a *decimal) RoundUp(nd int) { + if nd < 0 || nd >= a.nd { + return + } + + // round up + for i := nd - 1; i >= 0; i-- { + c := a.d[i] + if c < '9' { // can stop after this digit + a.d[i]++ + a.nd = i + 1 + return + } + } + + // Number is all 9s. + // Change to single 1 with adjusted decimal point. + a.d[0] = '1' + a.nd = 1 + a.dp++ +} + +// Extract integer part, rounded appropriately. +// No guarantees about overflow. +func (a *decimal) RoundedInteger() uint64 { + if a.dp > 20 { + return 0xFFFFFFFFFFFFFFFF + } + var i int + n := uint64(0) + for i = 0; i < a.dp && i < a.nd; i++ { + n = n*10 + uint64(a.d[i]-'0') + } + for ; i < a.dp; i++ { + n *= 10 + } + if shouldRoundUp(a, a.dp) { + n++ + } + return n +} diff --git a/vendor/github.com/shopspring/decimal/decimal.go b/vendor/github.com/shopspring/decimal/decimal.go index 20aa60806..134ece2ff 100644 --- a/vendor/github.com/shopspring/decimal/decimal.go +++ b/vendor/github.com/shopspring/decimal/decimal.go @@ -171,17 +171,84 @@ func RequireFromString(value string) Decimal { // NewFromFloat converts a float64 to Decimal. // -// Example: -// -// NewFromFloat(123.45678901234567).String() // output: "123.4567890123456" -// NewFromFloat(.00000000000000001).String() // output: "0.00000000000000001" +// The converted number will contain the number of significant digits that can be +// represented in a float with reliable roundtrip. +// This is typically 15 digits, but may be more in some cases. +// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information. // -// NOTE: some float64 numbers can take up about 300 bytes of memory in decimal representation. -// Consider using NewFromFloatWithExponent if space is more important than precision. +// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms. // // NOTE: this will panic on NaN, +/-inf func NewFromFloat(value float64) Decimal { - return NewFromFloatWithExponent(value, math.MinInt32) + if value == 0 { + return New(0, 0) + } + return newFromFloat(value, math.Float64bits(value), &float64info) +} + +// NewFromFloat converts a float32 to Decimal. +// +// The converted number will contain the number of significant digits that can be +// represented in a float with reliable roundtrip. +// This is typically 6-8 digits depending on the input. +// See https://www.exploringbinary.com/decimal-precision-of-binary-floating-point-numbers/ for more information. +// +// For slightly faster conversion, use NewFromFloatWithExponent where you can specify the precision in absolute terms. +// +// NOTE: this will panic on NaN, +/-inf +func NewFromFloat32(value float32) Decimal { + if value == 0 { + return New(0, 0) + } + // XOR is workaround for https://github.com/golang/go/issues/26285 + a := math.Float32bits(value) ^ 0x80808080 + return newFromFloat(float64(value), uint64(a)^0x80808080, &float32info) +} + +func newFromFloat(val float64, bits uint64, flt *floatInfo) Decimal { + if math.IsNaN(val) || math.IsInf(val, 0) { + panic(fmt.Sprintf("Cannot create a Decimal from %v", val)) + } + exp := int(bits>>flt.mantbits) & (1<>(flt.expbits+flt.mantbits) != 0 + + roundShortest(&d, mant, exp, flt) + // If less than 19 digits, we can do calculation in an int64. + if d.nd < 19 { + tmp := int64(0) + m := int64(1) + for i := d.nd - 1; i >= 0; i-- { + tmp += m * int64(d.d[i]-'0') + m *= 10 + } + if d.neg { + tmp *= -1 + } + return Decimal{value: big.NewInt(tmp), exp: int32(d.dp) - int32(d.nd)} + } + dValue := new(big.Int) + dValue, ok := dValue.SetString(string(d.d[:d.nd]), 10) + if ok { + return Decimal{value: dValue, exp: int32(d.dp) - int32(d.nd)} + } + + return NewFromFloatWithExponent(val, int32(d.dp)-int32(d.nd)) } // NewFromFloatWithExponent converts a float64 to Decimal, with an arbitrary @@ -376,6 +443,18 @@ func (d Decimal) Mul(d2 Decimal) Decimal { } } +// Shift shifts the decimal in base 10. +// It shifts left when shift is positive and right if shift is negative. +// In simpler terms, the given value for shift is added to the exponent +// of the decimal. +func (d Decimal) Shift(shift int32) Decimal { + d.ensureInitialized() + return Decimal{ + value: new(big.Int).Set(d.value), + exp: d.exp + shift, + } +} + // Div returns d / d2. If it doesn't divide exactly, the result will have // DivisionPrecision digits after the decimal point. func (d Decimal) Div(d2 Decimal) Decimal { @@ -544,6 +623,33 @@ func (d Decimal) Sign() int { return d.value.Sign() } +// IsPositive return +// +// true if d > 0 +// false if d == 0 +// false if d < 0 +func (d Decimal) IsPositive() bool { + return d.Sign() == 1 +} + +// IsNegative return +// +// true if d < 0 +// false if d == 0 +// false if d > 0 +func (d Decimal) IsNegative() bool { + return d.Sign() == -1 +} + +// IsZero return +// +// true if d == 0 +// false if d > 0 +// false if d < 0 +func (d Decimal) IsZero() bool { + return d.Sign() == 0 +} + // Exponent returns the exponent, or scale component of the decimal. func (d Decimal) Exponent() int32 { return d.exp @@ -1105,3 +1211,224 @@ func (d NullDecimal) MarshalJSON() ([]byte, error) { } return d.Decimal.MarshalJSON() } + +// Trig functions + +// Atan returns the arctangent, in radians, of x. +func (x Decimal) Atan() Decimal { + if x.Equal(NewFromFloat(0.0)) { + return x + } + if x.GreaterThan(NewFromFloat(0.0)) { + return x.satan() + } + return x.Neg().satan().Neg() +} + +func (d Decimal) xatan() Decimal { + P0 := NewFromFloat(-8.750608600031904122785e-01) + P1 := NewFromFloat(-1.615753718733365076637e+01) + P2 := NewFromFloat(-7.500855792314704667340e+01) + P3 := NewFromFloat(-1.228866684490136173410e+02) + P4 := NewFromFloat(-6.485021904942025371773e+01) + Q0 := NewFromFloat(2.485846490142306297962e+01) + Q1 := NewFromFloat(1.650270098316988542046e+02) + Q2 := NewFromFloat(4.328810604912902668951e+02) + Q3 := NewFromFloat(4.853903996359136964868e+02) + Q4 := NewFromFloat(1.945506571482613964425e+02) + z := d.Mul(d) + b1 := P0.Mul(z).Add(P1).Mul(z).Add(P2).Mul(z).Add(P3).Mul(z).Add(P4).Mul(z) + b2 := z.Add(Q0).Mul(z).Add(Q1).Mul(z).Add(Q2).Mul(z).Add(Q3).Mul(z).Add(Q4) + z = b1.Div(b2) + z = d.Mul(z).Add(d) + return z +} + +// satan reduces its argument (known to be positive) +// to the range [0, 0.66] and calls xatan. +func (d Decimal) satan() Decimal { + Morebits := NewFromFloat(6.123233995736765886130e-17) // pi/2 = PIO2 + Morebits + Tan3pio8 := NewFromFloat(2.41421356237309504880) // tan(3*pi/8) + pi := NewFromFloat(3.14159265358979323846264338327950288419716939937510582097494459) + + if d.LessThanOrEqual(NewFromFloat(0.66)) { + return d.xatan() + } + if d.GreaterThan(Tan3pio8) { + return pi.Div(NewFromFloat(2.0)).Sub(NewFromFloat(1.0).Div(d).xatan()).Add(Morebits) + } + return pi.Div(NewFromFloat(4.0)).Add((d.Sub(NewFromFloat(1.0)).Div(d.Add(NewFromFloat(1.0)))).xatan()).Add(NewFromFloat(0.5).Mul(Morebits)) +} + +// sin coefficients + var _sin = [...]Decimal{ + NewFromFloat(1.58962301576546568060E-10), // 0x3de5d8fd1fd19ccd + NewFromFloat(-2.50507477628578072866E-8), // 0xbe5ae5e5a9291f5d + NewFromFloat(2.75573136213857245213E-6), // 0x3ec71de3567d48a1 + NewFromFloat(-1.98412698295895385996E-4), // 0xbf2a01a019bfdf03 + NewFromFloat(8.33333333332211858878E-3), // 0x3f8111111110f7d0 + NewFromFloat(-1.66666666666666307295E-1), // 0xbfc5555555555548 + } + +// Sin returns the sine of the radian argument x. + func (d Decimal) Sin() Decimal { + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } else { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } + if sign { + y = y.Neg() + } + return y + } + + // cos coefficients + var _cos = [...]Decimal{ + NewFromFloat(-1.13585365213876817300E-11), // 0xbda8fa49a0861a9b + NewFromFloat(2.08757008419747316778E-9), // 0x3e21ee9d7b4e3f05 + NewFromFloat(-2.75573141792967388112E-7), // 0xbe927e4f7eac4bc6 + NewFromFloat(2.48015872888517045348E-5), // 0x3efa01a019c844f5 + NewFromFloat(-1.38888888888730564116E-3), // 0xbf56c16c16c14f91 + NewFromFloat(4.16666666666665929218E-2), // 0x3fa555555555554b + } + + // Cos returns the cosine of the radian argument x. + func (d Decimal) Cos() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + // make argument positive + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + j &= 7 // octant modulo 2Pi radians (360 degrees) + // reflect in x axis + if j > 3 { + sign = !sign + j -= 4 + } + if j > 1 { + sign = !sign + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if j == 1 || j == 2 { + y = z.Add(z.Mul(zz).Mul(_sin[0].Mul(zz).Add(_sin[1]).Mul(zz).Add(_sin[2]).Mul(zz).Add(_sin[3]).Mul(zz).Add(_sin[4]).Mul(zz).Add(_sin[5]))) + } else { + w := zz.Mul(zz).Mul(_cos[0].Mul(zz).Add(_cos[1]).Mul(zz).Add(_cos[2]).Mul(zz).Add(_cos[3]).Mul(zz).Add(_cos[4]).Mul(zz).Add(_cos[5])) + y = NewFromFloat(1.0).Sub(NewFromFloat(0.5).Mul(zz)).Add(w) + } + if sign { + y = y.Neg() + } + return y + } + + var _tanP = [...]Decimal{ + NewFromFloat(-1.30936939181383777646E+4), // 0xc0c992d8d24f3f38 + NewFromFloat(1.15351664838587416140E+6), // 0x413199eca5fc9ddd + NewFromFloat(-1.79565251976484877988E+7), // 0xc1711fead3299176 + } + var _tanQ = [...]Decimal{ + NewFromFloat(1.00000000000000000000E+0), + NewFromFloat(1.36812963470692954678E+4), //0x40cab8a5eeb36572 + NewFromFloat(-1.32089234440210967447E+6), //0xc13427bc582abc96 + NewFromFloat(2.50083801823357915839E+7), //0x4177d98fc2ead8ef + NewFromFloat(-5.38695755929454629881E+7), //0xc189afe03cbe5a31 + } + + // Tan returns the tangent of the radian argument x. + func (d Decimal) Tan() Decimal { + + PI4A := NewFromFloat(7.85398125648498535156E-1) // 0x3fe921fb40000000, Pi/4 split into three parts + PI4B := NewFromFloat(3.77489470793079817668E-8) // 0x3e64442d00000000, + PI4C := NewFromFloat(2.69515142907905952645E-15) // 0x3ce8469898cc5170, + M4PI := NewFromFloat(1.273239544735162542821171882678754627704620361328125) // 4/pi + + if d.Equal(NewFromFloat(0.0)) { + return d + } + + // make argument positive but save the sign + sign := false + if d.LessThan(NewFromFloat(0.0)) { + d = d.Neg() + sign = true + } + + j := d.Mul(M4PI).IntPart() // integer part of x/(Pi/4), as integer for tests on the phase angle + y := NewFromFloat(float64(j)) // integer part of x/(Pi/4), as float + + // map zeros to origin + if j&1 == 1 { + j++ + y = y.Add(NewFromFloat(1.0)) + } + + z := d.Sub(y.Mul(PI4A)).Sub(y.Mul(PI4B)).Sub(y.Mul(PI4C)) // Extended precision modular arithmetic + zz := z.Mul(z) + + if zz.GreaterThan(NewFromFloat(1e-14)) { + w := zz.Mul(_tanP[0].Mul(zz).Add(_tanP[1]).Mul(zz).Add(_tanP[2])) + x := zz.Add(_tanQ[1]).Mul(zz).Add(_tanQ[2]).Mul(zz).Add(_tanQ[3]).Mul(zz).Add(_tanQ[4]) + y = z.Add(z.Mul(w.Div(x))) + } else { + y = z + } + if j&2 == 2 { + y = NewFromFloat(-1.0).Div(y) + } + if sign { + y = y.Neg() + } + return y + } diff --git a/vendor/github.com/shopspring/decimal/rounding.go b/vendor/github.com/shopspring/decimal/rounding.go new file mode 100644 index 000000000..fdd74eaa8 --- /dev/null +++ b/vendor/github.com/shopspring/decimal/rounding.go @@ -0,0 +1,118 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Multiprecision decimal numbers. +// For floating-point formatting only; not general purpose. +// Only operations are assign and (binary) left/right shift. +// Can do binary floating point in multiprecision decimal precisely +// because 2 divides 10; cannot do decimal floating point +// in multiprecision binary precisely. +package decimal + +type floatInfo struct { + mantbits uint + expbits uint + bias int +} + +var float32info = floatInfo{23, 8, -127} +var float64info = floatInfo{52, 11, -1023} + +// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits +// that will let the original floating point value be precisely reconstructed. +func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) { + // If mantissa is zero, the number is zero; stop now. + if mant == 0 { + d.nd = 0 + return + } + + // Compute upper and lower such that any decimal number + // between upper and lower (possibly inclusive) + // will round to the original floating point number. + + // We may see at once that the number is already shortest. + // + // Suppose d is not denormal, so that 2^exp <= d < 10^dp. + // The closest shorter number is at least 10^(dp-nd) away. + // The lower/upper bounds computed below are at distance + // at most 2^(exp-mantbits). + // + // So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits), + // or equivalently log2(10)*(dp-nd) > exp-mantbits. + // It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32). + minexp := flt.bias + 1 // minimum possible exponent + if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) { + // The number is already shortest. + return + } + + // d = mant << (exp - mantbits) + // Next highest floating point number is mant+1 << exp-mantbits. + // Our upper bound is halfway between, mant*2+1 << exp-mantbits-1. + upper := new(decimal) + upper.Assign(mant*2 + 1) + upper.Shift(exp - int(flt.mantbits) - 1) + + // d = mant << (exp - mantbits) + // Next lowest floating point number is mant-1 << exp-mantbits, + // unless mant-1 drops the significant bit and exp is not the minimum exp, + // in which case the next lowest is mant*2-1 << exp-mantbits-1. + // Either way, call it mantlo << explo-mantbits. + // Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1. + var mantlo uint64 + var explo int + if mant > 1<