diff --git a/cluster.go b/cluster.go index 4e54c5d81..ff431cac5 100644 --- a/cluster.go +++ b/cluster.go @@ -48,34 +48,26 @@ import ( type mongoCluster struct { sync.RWMutex - serverSynced sync.Cond - userSeeds []string - dynaSeeds []string - servers mongoServers - masters mongoServers - references int - syncing bool - direct bool - failFast bool - syncCount uint - setName string - cachedIndex map[string]bool - sync chan bool - dial dialer - appName string - minPoolSize int - maxIdleTimeMS int + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + syncCount uint + cachedIndex map[string]bool + sync chan bool + dial dialer + dialInfo *DialInfo } -func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string, appName string) *mongoCluster { +func newCluster(userSeeds []string, info *DialInfo) *mongoCluster { cluster := &mongoCluster{ userSeeds: userSeeds, references: 1, - direct: direct, - failFast: failFast, - dial: dial, - setName: setName, - appName: appName, + dial: dialer{info.Dial, info.DialServer}, + dialInfo: info, } cluster.serverSynced.L = cluster.RWMutex.RLocker() cluster.sync = make(chan bool, 1) @@ -147,7 +139,7 @@ type isMasterResult struct { func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { // Monotonic let's it talk to a slave and still hold the socket. 
- session := newSession(Monotonic, cluster, 10*time.Second) + session := newSession(Monotonic, cluster, cluster.dialInfo) session.setSocket(socket) var cmd = bson.D{{Name: "isMaster", Value: 1}} @@ -171,8 +163,8 @@ func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResul } // Include the application name if set - if cluster.appName != "" { - meta["application"] = bson.M{"name": cluster.appName} + if cluster.dialInfo.AppName != "" { + meta["application"] = bson.M{"name": cluster.dialInfo.AppName} } cmd = append(cmd, bson.DocElem{ @@ -190,19 +182,7 @@ type possibleTimeout interface { Timeout() bool } -var syncSocketTimeout = 5 * time.Second - func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { - var syncTimeout time.Duration - if raceDetector { - // This variable is only ever touched by tests. - globalMutex.Lock() - syncTimeout = syncSocketTimeout - globalMutex.Unlock() - } else { - syncTimeout = syncSocketTimeout - } - addr := server.Addr log("SYNC Processing ", addr, "...") @@ -210,7 +190,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI var result isMasterResult var tryerr error for retry := 0; ; retry++ { - if retry == 3 || retry == 1 && cluster.failFast { + if retry == 3 || retry == 1 && cluster.dialInfo.FailFast { return nil, nil, tryerr } if retry > 0 { @@ -222,16 +202,22 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI time.Sleep(syncShortDelay) } - // It's not clear what would be a good timeout here. Is it - // better to wait longer or to retry? 
- socket, _, err := server.AcquireSocket(0, syncTimeout) + // Don't ever hit the pool limit for syncing + config := cluster.dialInfo.Copy() + config.PoolLimit = 0 + + socket, _, err := server.AcquireSocket(config) if err != nil { tryerr = err logf("SYNC Failed to get socket to %s: %v", addr, err) continue } err = cluster.isMaster(socket, &result) + + // Restore the correct dial config before returning it to the pool + socket.dialInfo = cluster.dialInfo socket.Release() + if err != nil { tryerr = err logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) @@ -241,9 +227,9 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI break } - if cluster.setName != "" && result.SetName != cluster.setName { - logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName) - return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName) + if cluster.dialInfo.ReplicaSetName != "" && result.SetName != cluster.dialInfo.ReplicaSetName { + logf("SYNC Server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) + return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) } if result.IsMaster { @@ -255,7 +241,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI } } else if result.Secondary { debugf("SYNC %s is a slave.", addr) - } else if cluster.direct { + } else if cluster.dialInfo.Direct { logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) } else { logf("SYNC %s is neither a master nor a slave.", addr) @@ -386,7 +372,7 @@ func (cluster *mongoCluster) syncServersLoop() { break } cluster.references++ // Keep alive while syncing. 
- direct := cluster.direct + direct := cluster.dialInfo.Direct cluster.Unlock() cluster.syncServersIteration(direct) @@ -401,7 +387,7 @@ func (cluster *mongoCluster) syncServersLoop() { // Hold off before allowing another sync. No point in // burning CPU looking for down servers. - if !cluster.failFast { + if !cluster.dialInfo.FailFast { time.Sleep(syncShortDelay) } @@ -439,13 +425,11 @@ func (cluster *mongoCluster) syncServersLoop() { func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { cluster.RLock() server := cluster.servers.Search(tcpaddr.String()) - minPoolSize := cluster.minPoolSize - maxIdleTimeMS := cluster.maxIdleTimeMS cluster.RUnlock() if server != nil { return server } - return newServer(addr, tcpaddr, cluster.sync, cluster.dial, minPoolSize, maxIdleTimeMS) + return newServer(addr, tcpaddr, cluster.sync, cluster.dial, cluster.dialInfo) } func resolveAddr(addr string) (*net.TCPAddr, error) { @@ -614,19 +598,10 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) { cluster.Unlock() } -// AcquireSocket returns a socket to a server in the cluster. If slaveOk is -// true, it will attempt to return a socket to a slave server. If it is -// false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { - return cluster.AcquireSocketWithPoolTimeout(mode, slaveOk, syncTimeout, socketTimeout, serverTags, poolLimit, 0) -} - // AcquireSocketWithPoolTimeout returns a socket to a server in the cluster. If slaveOk is // true, it will attempt to return a socket to a slave server. If it is // false, the socket will necessarily be to a master server. 
-func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( - mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int, poolTimeout time.Duration, -) (s *mongoSocket, err error) { +func (cluster *mongoCluster) AcquireSocketWithPoolTimeout(mode Mode, slaveOk bool, syncTimeout time.Duration, serverTags []bson.D, info *DialInfo) (s *mongoSocket, err error) { var started time.Time var syncCount uint for { @@ -645,7 +620,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( // Initialize after fast path above. started = time.Now() syncCount = cluster.syncCount - } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.dialInfo.FailFast && cluster.syncCount != syncCount { cluster.RUnlock() return nil, errors.New("no reachable servers") } @@ -670,7 +645,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( continue } - s, abended, err := server.AcquireSocketWithBlocking(poolLimit, socketTimeout, poolTimeout) + s, abended, err := server.AcquireSocketWithBlocking(info) if err == errPoolTimeout { // No need to remove servers from the topology if acquiring a socket fails for this reason. 
return nil, err diff --git a/cluster_test.go b/cluster_test.go index be11dc1a7..de99d414d 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1055,8 +1055,6 @@ func (s *S) TestSocketTimeoutOnDial(c *C) { timeout := 1 * time.Second - defer mgo.HackSyncSocketTimeout(timeout)() - s.Freeze("localhost:40001") started := time.Now() diff --git a/export_test.go b/export_test.go index 998c7a2dd..1b7d7e941 100644 --- a/export_test.go +++ b/export_test.go @@ -19,20 +19,6 @@ func HackPingDelay(newDelay time.Duration) (restore func()) { return } -func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { - globalMutex.Lock() - defer globalMutex.Unlock() - - oldTimeout := syncSocketTimeout - restore = func() { - globalMutex.Lock() - syncSocketTimeout = oldTimeout - globalMutex.Unlock() - } - syncSocketTimeout = newTimeout - return -} - func (s *Session) Cluster() *mongoCluster { return s.cluster() } diff --git a/server.go b/server.go index f34624f74..6f51ca5e3 100644 --- a/server.go +++ b/server.go @@ -67,9 +67,8 @@ type mongoServer struct { pingCount uint32 closed bool abended bool - minPoolSize int - maxIdleTimeMS int poolWaiter *sync.Cond + dialInfo *DialInfo } type dialer struct { @@ -91,21 +90,20 @@ type mongoServerInfo struct { var defaultServerInfo mongoServerInfo -func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, minPoolSize, maxIdleTimeMS int) *mongoServer { +func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, info *DialInfo) *mongoServer { server := &mongoServer{ - Addr: addr, - ResolvedAddr: tcpaddr.String(), - tcpaddr: tcpaddr, - sync: syncChan, - dial: dial, - info: &defaultServerInfo, - pingValue: time.Hour, // Push it back before an actual ping. 
- minPoolSize: minPoolSize, - maxIdleTimeMS: maxIdleTimeMS, + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: syncChan, + dial: dial, + info: &defaultServerInfo, + pingValue: time.Hour, // Push it back before an actual ping. + dialInfo: info, } server.poolWaiter = sync.NewCond(server) go server.pinger(true) - if maxIdleTimeMS != 0 { + if info.MaxIdleTimeMS != 0 { go server.poolShrinker() } return server @@ -123,22 +121,18 @@ var errServerClosed = errors.New("server was closed") // If the poolLimit argument is greater than zero and the number of sockets in // use in this server is greater than the provided limit, errPoolLimit is // returned. -func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, timeout, false, 0*time.Millisecond) +func (server *mongoServer) AcquireSocket(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, false) } // AcquireSocketWithBlocking wraps AcquireSocket, but if a socket is not available, it will _not_ // return errPoolLimit. Instead, it will block waiting for a socket to become available. If poolTimeout // should elapse before a socket is available, it will return errPoolTimeout. 
-func (server *mongoServer) AcquireSocketWithBlocking( - poolLimit int, socketTimeout time.Duration, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, socketTimeout, true, poolTimeout) +func (server *mongoServer) AcquireSocketWithBlocking(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, true) } -func (server *mongoServer) acquireSocketInternal( - poolLimit int, timeout time.Duration, shouldBlock bool, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { +func (server *mongoServer) acquireSocketInternal(info *DialInfo, shouldBlock bool) (socket *mongoSocket, abended bool, err error) { for { server.Lock() abended = server.abended @@ -146,7 +140,7 @@ func (server *mongoServer) acquireSocketInternal( server.Unlock() return nil, abended, errServerClosed } - if poolLimit > 0 { + if info.PoolLimit > 0 { if shouldBlock { // Beautiful. Golang conditions don't have a WaitWithTimeout, so I've implemented the timeout // with a wait + broadcast. The broadcast will cause the loop here to re-check the timeout, @@ -158,11 +152,11 @@ func (server *mongoServer) acquireSocketInternal( // https://github.com/golang/go/issues/16620, since the lock needs to be held in _this_ goroutine. waitDone := make(chan struct{}) timeoutHit := false - if poolTimeout > 0 { + if info.PoolTimeout > 0 { go func() { select { case <-waitDone: - case <-time.After(poolTimeout): + case <-time.After(info.PoolTimeout): // timeoutHit is part of the wait condition, so needs to be changed under mutex. 
server.Lock() defer server.Unlock() @@ -172,7 +166,7 @@ func (server *mongoServer) acquireSocketInternal( }() } timeSpentWaiting := time.Duration(0) - for len(server.liveSockets)-len(server.unusedSockets) >= poolLimit && !timeoutHit { + for len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit && !timeoutHit { // We only count time spent in Wait(), and not time evaluating the entire loop, // so that in the happy non-blocking path where the condition above evaluates true // first time, we record a nice round zero wait time. @@ -191,7 +185,7 @@ func (server *mongoServer) acquireSocketInternal( // Record that we fetched a connection of of a socket list and how long we spent waiting stats.noticeSocketAcquisition(timeSpentWaiting) } else { - if len(server.liveSockets)-len(server.unusedSockets) >= poolLimit { + if len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit { server.Unlock() return nil, false, errPoolLimit } @@ -202,15 +196,15 @@ func (server *mongoServer) acquireSocketInternal( socket = server.unusedSockets[n-1] server.unusedSockets[n-1] = nil // Help GC. server.unusedSockets = server.unusedSockets[:n-1] - info := server.info + serverInfo := server.info server.Unlock() - err = socket.InitialAcquire(info, timeout) + err = socket.InitialAcquire(serverInfo, info) if err != nil { continue } } else { server.Unlock() - socket, err = server.Connect(timeout) + socket, err = server.Connect(info) if err == nil { server.Lock() // We've waited for the Connect, see if we got @@ -231,20 +225,18 @@ func (server *mongoServer) acquireSocketInternal( // Connect establishes a new connection to the server. This should // generally be done through server.AcquireSocket(). 
-func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { +func (server *mongoServer) Connect(info *DialInfo) (*mongoSocket, error) { server.RLock() master := server.info.Master dial := server.dial server.RUnlock() - logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, info.Timeout) var conn net.Conn var err error switch { case !dial.isSet(): - // Cannot do this because it lacks timeout support. :-( - //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) - conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, info.Timeout) if tcpconn, ok := conn.(*net.TCPConn); ok { tcpconn.SetKeepAlive(true) } else if err == nil { @@ -264,7 +256,7 @@ func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) logf("Connection to %s established.", server.Addr) stats.conn(+1, master) - return newSocket(server, conn, timeout), nil + return newSocket(server, conn, info), nil } // Close forces closing all sockets that are alive, whether @@ -407,7 +399,8 @@ func (server *mongoServer) pinger(loop bool) { time.Sleep(delay) } op := op - socket, _, err := server.AcquireSocket(0, delay) + + socket, _, err := server.AcquireSocket(server.dialInfo) if err == nil { start := time.Now() _, _ = socket.SimpleQuery(&op) @@ -448,7 +441,7 @@ func (server *mongoServer) poolShrinker() { } server.Lock() unused := len(server.unusedSockets) - if unused < server.minPoolSize { + if unused < server.dialInfo.MinPoolSize { server.Unlock() continue } @@ -457,8 +450,8 @@ func (server *mongoServer) poolShrinker() { reclaimMap := map[*mongoSocket]struct{}{} // Because the acquisition and recycle are done at the tail of array, // the head is always the oldest unused socket. 
- for _, s := range server.unusedSockets[:unused-server.minPoolSize] { - if s.lastTimeUsed.Add(time.Duration(server.maxIdleTimeMS) * time.Millisecond).After(now) { + for _, s := range server.unusedSockets[:unused-server.dialInfo.MinPoolSize] { + if s.lastTimeUsed.Add(time.Duration(server.dialInfo.MaxIdleTimeMS) * time.Millisecond).After(now) { break } end++ @@ -572,7 +565,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe if best == nil { best = next best.RLock() - if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + if len(serverTags) != 0 && !next.info.Mongos && !best.hasTags(serverTags) { best.RUnlock() best = nil } @@ -581,7 +574,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe next.RLock() swap := false switch { - case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + case len(serverTags) != 0 && !next.info.Mongos && !next.hasTags(serverTags): // Must have requested tags. case mode == Secondary && next.info.Master && !next.info.Mongos: // Must be a secondary or mongos. diff --git a/server_test.go b/server_test.go index 1d21ef08b..43ddfa3b1 100644 --- a/server_test.go +++ b/server_test.go @@ -29,8 +29,8 @@ package mgo_test import ( "time" - . "gopkg.in/check.v1" "github.com/globalsign/mgo" + . 
"gopkg.in/check.v1" ) func (s *S) TestServerRecoversFromAbend(c *C) { @@ -40,7 +40,13 @@ func (s *S) TestServerRecoversFromAbend(c *C) { // Peek behind the scenes cluster := session.Cluster() server := cluster.Server("127.0.0.1:40001") - sock, abended, err := server.AcquireSocket(100, time.Second) + + info := &mgo.DialInfo{ + Timeout: time.Second, + PoolLimit: 100, + } + + sock, abended, err := server.AcquireSocket(info) c.Assert(err, IsNil) c.Assert(sock, NotNil) sock.Release() @@ -49,15 +55,15 @@ func (s *S) TestServerRecoversFromAbend(c *C) { sock.Close() server.AbendSocket(sock) // Next acquire notices the connection was abnormally ended - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) sock.Release() c.Check(abended, Equals, true) - // cluster.AcquireSocket should fix the abended problems - sock, err = cluster.AcquireSocket(mgo.Primary, false, time.Minute, time.Second, nil, 100) + // cluster.AcquireSocketWithPoolTimeout should fix the abended problems + sock, err = cluster.AcquireSocketWithPoolTimeout(mgo.Primary, false, time.Minute, nil, info) c.Assert(err, IsNil) sock.Release() - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) c.Check(abended, Equals, false) sock.Release() diff --git a/session.go b/session.go index 167a0375d..edd43f717 100644 --- a/session.go +++ b/session.go @@ -73,6 +73,14 @@ const ( Monotonic Mode = 1 // Strong mode is specific to mgo, and is same as Primary. Strong Mode = 2 + + // DefaultConnectionPoolLimit defines the default maximum number of + // connections in the connection pool. + // + // To override this value set DialInfo.PoolLimit. + DefaultConnectionPoolLimit = 4096 + + zeroDuration = time.Duration(0) ) // mgo.v3: Drop Strong mode, suffix all modes with "Mode". 
@@ -90,9 +98,6 @@ type Session struct { defaultdb string sourcedb string syncTimeout time.Duration - sockTimeout time.Duration - poolLimit int - poolTimeout time.Duration consistency Mode creds []Credential dialCred *Credential @@ -104,6 +109,8 @@ type Session struct { queryConfig query bypassValidation bool slaveOk bool + + dialInfo *DialInfo } // Database holds collections of documents @@ -196,7 +203,7 @@ const ( // Dial will timeout after 10 seconds if a server isn't reached. The returned // session will timeout operations after one minute by default if servers aren't // available. To customize the timeout, see DialWithTimeout, SetSyncTimeout, and -// SetSocketTimeout. +// DialInfo Read/WriteTimeout. // // This method is generally called just once for a given cluster. Further // sessions to the same cluster are then established using the New or Copy @@ -483,15 +490,38 @@ type DialInfo struct { Username string Password string - // PoolLimit defines the per-server socket pool limit. Defaults to 4096. - // See Session.SetPoolLimit for details. + // PoolLimit defines the per-server socket pool limit. Defaults to + // DefaultConnectionPoolLimit. See Session.SetPoolLimit for details. PoolLimit int // PoolTimeout defines max time to wait for a connection to become available - // if the pool limit is reaqched. Defaults to zero, which means forever. - // See Session.SetPoolTimeout for details + // if the pool limit is reached. Defaults to zero, which means forever. See + // Session.SetPoolTimeout for details PoolTimeout time.Duration + // ReadTimeout defines the maximum duration to wait for a response to be + // read from MongoDB. + // + // This effectively limits the maximum query execution time. If a MongoDB + // query duration exceeds this timeout, the caller will receive a timeout, + // however MongoDB will continue processing the query. 
This duration must be + // large enough to allow MongoDB to execute the query, and the response be + // received over the network connection. + // + // Only limits the network read - does not include unmarshalling / + // processing of the response. Defaults to DialInfo.Timeout. If 0, no + // timeout is set. + ReadTimeout time.Duration + + // WriteTimeout defines the maximum duration of a write to MongoDB over the + // network connection. + // + // This can usually be low unless writing large documents, or over a high + // latency link. Only limits network write time - does not include + // marshalling/processing the request. Defaults to DialInfo.Timeout. If 0, + // no timeout is set. + WriteTimeout time.Duration + // The identifier of the client application which ran the operation. AppName string @@ -515,7 +545,7 @@ type DialInfo struct { // Defaults to 0. MinPoolSize int - //The maximum number of milliseconds that a connection can remain idle in the pool + // The maximum number of milliseconds that a connection can remain idle in the pool // before being removed and closed. MaxIdleTimeMS int @@ -527,6 +557,79 @@ type DialInfo struct { Dial func(addr net.Addr) (net.Conn, error) } +// Copy returns a deep-copy of i.
+func (i *DialInfo) Copy() *DialInfo { + var readPreference *ReadPreference + if i.ReadPreference != nil { + readPreference = &ReadPreference{ + Mode: i.ReadPreference.Mode, + } + readPreference.TagSets = make([]bson.D, len(i.ReadPreference.TagSets)) + copy(readPreference.TagSets, i.ReadPreference.TagSets) + } + + info := &DialInfo{ + Timeout: i.Timeout, + Database: i.Database, + ReplicaSetName: i.ReplicaSetName, + Source: i.Source, + Service: i.Service, + ServiceHost: i.ServiceHost, + Mechanism: i.Mechanism, + Username: i.Username, + Password: i.Password, + PoolLimit: i.PoolLimit, + PoolTimeout: i.PoolTimeout, + ReadTimeout: i.ReadTimeout, + WriteTimeout: i.WriteTimeout, + AppName: i.AppName, + ReadPreference: readPreference, + FailFast: i.FailFast, + Direct: i.Direct, + MinPoolSize: i.MinPoolSize, + MaxIdleTimeMS: i.MaxIdleTimeMS, + DialServer: i.DialServer, + Dial: i.Dial, + } + + info.Addrs = make([]string, len(i.Addrs)) + copy(info.Addrs, i.Addrs) + + return info +} + +// readTimeout returns the configured read timeout, or i.Timeout if it's not set +func (i *DialInfo) readTimeout() time.Duration { + if i.ReadTimeout == zeroDuration { + return i.Timeout + } + return i.ReadTimeout +} + +// writeTimeout returns the configured write timeout, or i.Timeout if it's not +// set +func (i *DialInfo) writeTimeout() time.Duration { + if i.WriteTimeout == zeroDuration { + return i.Timeout + } + return i.WriteTimeout +} + +// roundTripTimeout returns the total time allocated for a single network read +// and write. +func (i *DialInfo) roundTripTimeout() time.Duration { + return i.readTimeout() + i.writeTimeout() +} + +// poolLimit returns the configured connection pool size, or +// DefaultConnectionPoolLimit. +func (i *DialInfo) poolLimit() int { + if i.PoolLimit == 0 { + return DefaultConnectionPoolLimit + } + return i.PoolLimit +} + // ReadPreference defines the manner in which servers are chosen. type ReadPreference struct { // Mode determines the consistency of results. 
See Session.SetMode. @@ -556,7 +659,12 @@ func (addr *ServerAddr) TCPAddr() *net.TCPAddr { } // DialWithInfo establishes a new session to the cluster identified by info. -func DialWithInfo(info *DialInfo) (*Session, error) { +func DialWithInfo(dialInfo *DialInfo) (*Session, error) { + info := dialInfo.Copy() + info.PoolLimit = info.poolLimit() + info.ReadTimeout = info.readTimeout() + info.WriteTimeout = info.writeTimeout() + addrs := make([]string, len(info.Addrs)) for i, addr := range info.Addrs { p := strings.LastIndexAny(addr, "]:") @@ -566,8 +674,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } addrs[i] = addr } - cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName, info.AppName) - session := newSession(Eventual, cluster, info.Timeout) + cluster := newCluster(addrs, info) + session := newSession(Eventual, cluster, info) session.defaultdb = info.Database if session.defaultdb == "" { session.defaultdb = "test" @@ -595,16 +703,6 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } session.creds = []Credential{*session.dialCred} } - if info.PoolLimit > 0 { - session.poolLimit = info.PoolLimit - } - - cluster.minPoolSize = info.MinPoolSize - cluster.maxIdleTimeMS = info.MaxIdleTimeMS - - if info.PoolTimeout > 0 { - session.poolTimeout = info.PoolTimeout - } cluster.Release() @@ -624,6 +722,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { session.SetMode(Strong, true) } + session.dialInfo = info + return session, nil } @@ -684,13 +784,12 @@ func extractURL(s string) (*urlInfo, error) { return info, nil } -func newSession(consistency Mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { +func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) { cluster.Acquire() session = &Session{ mgoCluster: cluster, - syncTimeout: timeout, - sockTimeout: timeout, - poolLimit: 4096, + syncTimeout: info.Timeout, + dialInfo: info, } 
debugf("New session %p on cluster %p", session, cluster) session.SetMode(consistency, true) @@ -719,9 +818,6 @@ func copySession(session *Session, keepCreds bool) (s *Session) { defaultdb: session.defaultdb, sourcedb: session.sourcedb, syncTimeout: session.syncTimeout, - sockTimeout: session.sockTimeout, - poolLimit: session.poolLimit, - poolTimeout: session.poolTimeout, consistency: session.consistency, creds: creds, dialCred: session.dialCred, @@ -733,6 +829,7 @@ func copySession(session *Session, keepCreds bool) (s *Session) { queryConfig: session.queryConfig, bypassValidation: session.bypassValidation, slaveOk: session.slaveOk, + dialInfo: session.dialInfo, } s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) @@ -2018,13 +2115,21 @@ func (s *Session) SetSyncTimeout(d time.Duration) { s.m.Unlock() } -// SetSocketTimeout sets the amount of time to wait for a non-responding -// socket to the database before it is forcefully closed. +// SetSocketTimeout is deprecated - use DialInfo read/write timeouts instead. +// +// SetSocketTimeout sets the amount of time to wait for a non-responding socket +// to the database before it is forcefully closed. // // The default timeout is 1 minute. func (s *Session) SetSocketTimeout(d time.Duration) { s.m.Lock() - s.sockTimeout = d + + // Set both the read and write timeout, as well as the DialInfo.Timeout for + // backwards compatibility. + s.dialInfo.Timeout = d + s.dialInfo.ReadTimeout = d + s.dialInfo.WriteTimeout = d + if s.masterSocket != nil { s.masterSocket.SetTimeout(d) } @@ -2058,7 +2163,7 @@ func (s *Session) SetCursorTimeout(d time.Duration) { // of used resources and number of goroutines before they are created. func (s *Session) SetPoolLimit(limit int) { s.m.Lock() - s.poolLimit = limit + s.dialInfo.PoolLimit = limit s.m.Unlock() } @@ -2068,7 +2173,7 @@ func (s *Session) SetPoolLimit(limit int) { // The default value is zero, which means to wait forever with no timeout.
func (s *Session) SetPoolTimeout(timeout time.Duration) { s.m.Lock() - s.poolTimeout = timeout + s.dialInfo.PoolTimeout = timeout s.m.Unlock() } @@ -4356,11 +4461,13 @@ func (iter *Iter) acquireSocket() (*mongoSocket, error) { // with Eventual sessions, if a Refresh is done, or if a // monotonic session gets a write and shifts from secondary // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() - sockTimeout := iter.session.sockTimeout + info := iter.session.dialInfo iter.session.m.Unlock() + socket.Release() - socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + socket, _, err = iter.server.AcquireSocket(info) if err != nil { return nil, err } @@ -4950,7 +5057,11 @@ func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { // Still not good. We need a new socket. sock, err := s.cluster().AcquireSocketWithPoolTimeout( - s.consistency, slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit, s.poolTimeout, + s.consistency, + slaveOk && s.slaveOk, + s.syncTimeout, + s.queryConfig.op.serverTags, + s.dialInfo, ) if err != nil { return nil, err diff --git a/session_internal_test.go b/session_internal_test.go index ddce59cae..3e214b174 100644 --- a/session_internal_test.go +++ b/session_internal_test.go @@ -3,9 +3,11 @@ package mgo import ( "crypto/x509/pkix" "encoding/asn1" + "testing" + "time" + "github.com/globalsign/mgo/bson" . "gopkg.in/check.v1" - "testing" ) type S struct{} @@ -62,3 +64,22 @@ func (s *S) TestGetRFC2253NameStringMultiValued(c *C) { c.Assert(getRFC2253NameString(&RDNElements), Equals, "OU=Sales+CN=J. 
Smith,O=Widget Inc.,C=US") } + +func (s *S) TestDialTimeouts(c *C) { + info := &DialInfo{} + + c.Assert(info.readTimeout(), Equals, time.Duration(0)) + c.Assert(info.writeTimeout(), Equals, time.Duration(0)) + c.Assert(info.roundTripTimeout(), Equals, time.Duration(0)) + + info.Timeout = 60 * time.Second + c.Assert(info.readTimeout(), Equals, 60*time.Second) + c.Assert(info.writeTimeout(), Equals, 60*time.Second) + c.Assert(info.roundTripTimeout(), Equals, 120*time.Second) + + info.ReadTimeout = time.Second + c.Assert(info.readTimeout(), Equals, time.Second) + + info.WriteTimeout = time.Second + c.Assert(info.writeTimeout(), Equals, time.Second) +} diff --git a/socket.go b/socket.go index ae13e401f..9dcedf219 100644 --- a/socket.go +++ b/socket.go @@ -42,7 +42,6 @@ type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached conn net.Conn - timeout time.Duration addr string // For debugging only. nextRequestId uint32 replyFuncs map[uint32]replyFunc @@ -56,6 +55,8 @@ type mongoSocket struct { closeAfterIdle bool lastTimeUsed time.Time // for time based idle socket release sendMeta sync.Once + + dialInfo *DialInfo } type queryOpFlags uint32 @@ -181,15 +182,16 @@ type requestInfo struct { replyFunc replyFunc } -func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { +func newSocket(server *mongoServer, conn net.Conn, info *DialInfo) *mongoSocket { socket := &mongoSocket{ conn: conn, addr: server.Addr, server: server, replyFuncs: make(map[uint32]replyFunc), + dialInfo: info, } socket.gotNonce.L = &socket.Mutex - if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + if err := socket.InitialAcquire(server.Info(), info); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) @@ -223,7 +225,7 @@ func (socket *mongoSocket) ServerInfo() *mongoServerInfo { // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or 
once a recycled socket is // being put back in use. -func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, dialInfo *DialInfo) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") @@ -235,7 +237,7 @@ func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout t } socket.references++ socket.serverInfo = serverInfo - socket.timeout = timeout + socket.dialInfo = dialInfo stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() @@ -288,7 +290,8 @@ func (socket *mongoSocket) Release() { // SetTimeout changes the timeout used on socket operations. func (socket *mongoSocket) SetTimeout(d time.Duration) { socket.Lock() - socket.timeout = d + socket.dialInfo.ReadTimeout = d + socket.dialInfo.WriteTimeout = d socket.Unlock() } @@ -301,24 +304,37 @@ const ( func (socket *mongoSocket) updateDeadline(which deadlineType) { var when time.Time - if socket.timeout > 0 { - when = time.Now().Add(socket.timeout) - } - whichstr := "" + var whichStr string switch which { case readDeadline | writeDeadline: - whichstr = "read/write" + if socket.dialInfo.roundTripTimeout() == 0 { + return + } + whichStr = "read/write" + when = time.Now().Add(socket.dialInfo.roundTripTimeout()) socket.conn.SetDeadline(when) + case readDeadline: - whichstr = "read" + if socket.dialInfo.ReadTimeout == zeroDuration { + return + } + whichStr = "read" + when = time.Now().Add(socket.dialInfo.ReadTimeout) socket.conn.SetReadDeadline(when) + case writeDeadline: - whichstr = "write" + if socket.dialInfo.WriteTimeout == zeroDuration { + return + } + whichStr = "write" + when = time.Now().Add(socket.dialInfo.WriteTimeout) socket.conn.SetWriteDeadline(when) + default: panic("invalid parameter to updateDeadline") } - debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) 
+ + debugf("Socket %p to %s: updated %s deadline to %s", socket, socket.addr, whichStr, when) } // Close terminates the socket use.