From a15e541309bc9d0aadd9827fab9445a76bba1058 Mon Sep 17 00:00:00 2001 From: martonp Date: Wed, 25 Sep 2024 18:14:43 +0200 Subject: [PATCH] AutoReconnect and other fixes --- client/comms/wsconn.go | 175 ++++++++++++++++++++++++------------- client/core/core_test.go | 2 + client/mm/libxc/binance.go | 173 ++++++++++++++---------------------- 3 files changed, 182 insertions(+), 168 deletions(-) diff --git a/client/comms/wsconn.go b/client/comms/wsconn.go index 96ffabd082..63266894b7 100644 --- a/client/comms/wsconn.go +++ b/client/comms/wsconn.go @@ -98,6 +98,7 @@ type WsConn interface { RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error Connect(ctx context.Context) (*sync.WaitGroup, error) MessageSource() <-chan *msgjson.Message + UpdateURL(string) } // When the DEX sends a request to the client, a responseHandler is created @@ -118,6 +119,10 @@ type WsCfg struct { // latency. PingWait time.Duration + // AutoReconnect, if non-nil, will reconnect to the server after each + // interval of the amount of time specified. + AutoReconnect *time.Duration + // The server's certificate. Cert []byte @@ -161,6 +166,7 @@ type wsConn struct { cfg *WsCfg tlsCfg *tls.Config readCh chan *msgjson.Message + URL atomic.Value // string wsMtx sync.Mutex ws *websocket.Conn @@ -203,14 +209,23 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) { ServerName: uri.Hostname(), } - return &wsConn{ + conn := &wsConn{ cfg: cfg, log: cfg.Logger, tlsCfg: tlsConfig, readCh: make(chan *msgjson.Message, readBuffSize), respHandlers: make(map[uint64]*responseHandler), reconnectCh: make(chan struct{}, 1), - }, nil + } + + conn.URL.Store(cfg.URL) + + return conn, nil +} + +// UpdateURL updates the URL that the connection uses when reconnecting. +func (conn *wsConn) UpdateURL(rawURL string) { + conn.URL.Store(rawURL) } // IsDown indicates if the connection is known to be down. @@ -240,7 +255,7 @@ func (conn *wsConn) connect(ctx context.Context) error { dialer.Proxy = http.ProxyFromEnvironment } - ws, _, err := dialer.DialContext(ctx, conn.cfg.URL, conn.cfg.ConnectHeaders) + ws, _, err := dialer.DialContext(ctx, conn.URL.Load().(string), conn.cfg.ConnectHeaders) if err != nil { if isErrorInvalidCert(err) { conn.setConnectionStatus(InvalidCert) @@ -303,9 +318,9 @@ func (conn *wsConn) connect(ctx context.Context) error { go func() { defer conn.wg.Done() if conn.cfg.RawHandler != nil { - conn.readRaw(ctx) + conn.readRaw(ctx, ws) } else { - conn.read(ctx) + conn.read(ctx, ws) } }() @@ -331,7 +346,7 @@ func (conn *wsConn) handleReadError(err error) { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - conn.log.Errorf("Read timeout on connection to %s.", conn.cfg.URL) + conn.log.Errorf("Read timeout on connection to %s.", conn.URL.Load().(string)) reconnect() return } @@ -372,75 +387,110 @@ func (conn *wsConn) close() { conn.ws.Close() } -func (conn *wsConn) readRaw(ctx context.Context) { - for { - // Lock since conn.ws may be set by connect. - conn.wsMtx.Lock() - ws := conn.ws - conn.wsMtx.Unlock() +func (conn *wsConn) readRaw(ctx context.Context, ws *websocket.Conn) { + var reconnectTimer <-chan time.Time + if conn.cfg.AutoReconnect != nil { + reconnectTimer = time.After(*conn.cfg.AutoReconnect) + } + + type readResult struct { + msgBytes []byte + err error + } + + readMessage := func() chan *readResult { + ch := make(chan *readResult, 1) + go func() { + _, msgBytes, err := ws.ReadMessage() + ch <- &readResult{msgBytes, err} + }() + return ch + } - // Block until a message is received or an error occurs. - _, msgBytes, err := ws.ReadMessage() - // Drop the read error on context cancellation. - if ctx.Err() != nil { + for { + select { + case result := <-readMessage(): + if ctx.Err() != nil { + return + } + if result.err != nil { + conn.handleReadError(result.err) + return + } + conn.cfg.RawHandler(result.msgBytes) + case <-reconnectTimer: + conn.reconnectCh <- struct{}{} return - } - if err != nil { - conn.handleReadError(err) + case <-ctx.Done(): return } - conn.cfg.RawHandler(msgBytes) } } // read fetches and parses incoming messages for processing. This should be // run as a goroutine. Increment the wg before calling read. -func (conn *wsConn) read(ctx context.Context) { - for { - msg := new(msgjson.Message) +func (conn *wsConn) read(ctx context.Context, ws *websocket.Conn) { + var reconnectTimer <-chan time.Time + if conn.cfg.AutoReconnect != nil { + reconnectTimer = time.After(*conn.cfg.AutoReconnect) + } - // Lock since conn.ws may be set by connect. - conn.wsMtx.Lock() - ws := conn.ws - conn.wsMtx.Unlock() + type readResult struct { + msg *msgjson.Message + err error + } - // The read itself does not require locking since only this goroutine - // uses read functions that are not safe for concurrent use. - err := ws.ReadJSON(msg) - // Drop the read error on context cancellation. - if ctx.Err() != nil { - return - } - if err != nil { - var mErr *json.UnmarshalTypeError - if errors.As(err, &mErr) { - // JSON decode errors are not fatal, log and proceed. - conn.log.Errorf("json decode error: %v", mErr) - continue - } - conn.handleReadError(err) - return - } + readMessage := func() chan *readResult { + ch := make(chan *readResult, 1) + go func() { + msg := new(msgjson.Message) + err := ws.ReadJSON(msg) + ch <- &readResult{msg, err} + }() + return ch + } - // If the message is a response, find the handler. - if msg.Type == msgjson.Response { - handler := conn.respHandler(msg.ID) - if handler == nil { - b, _ := json.Marshal(msg) - conn.log.Errorf("No handler found for response: %v", string(b)) + for { + select { + case result := <-readMessage(): + if ctx.Err() != nil { + return + } + if result.err != nil { + var mErr *json.UnmarshalTypeError + if errors.As(result.err, &mErr) { + // JSON decode errors are not fatal, log and proceed. + conn.log.Errorf("json decode error: %v", mErr) + continue + } + conn.handleReadError(result.err) + return + } + // If the message is a response, find the handler. + if result.msg.Type == msgjson.Response { + handler := conn.respHandler(result.msg.ID) + if handler == nil { + b, _ := json.Marshal(result.msg) + conn.log.Errorf("No handler found for response: %v", string(b)) + continue + } + // Run handlers in a goroutine so that other messages can be + // received. Include the handler goroutines in the WaitGroup to + // allow them to complete if the connection master desires. + conn.wg.Add(1) + go func() { + defer conn.wg.Done() + handler.f(result.msg) + }() continue } - // Run handlers in a goroutine so that other messages can be - // received. Include the handler goroutines in the WaitGroup to - // allow them to complete if the connection master desires. - conn.wg.Add(1) - go func() { - defer conn.wg.Done() - handler.f(msg) - }() - continue + conn.readCh <- result.msg + case <-reconnectTimer: + conn.reconnectCh <- struct{}{} + return + case <-ctx.Done(): + return } - conn.readCh <- msg } } @@ -457,11 +507,11 @@ func (conn *wsConn) keepAlive(ctx context.Context) { return } - conn.log.Infof("Attempting to reconnect to %s...", conn.cfg.URL) + conn.log.Infof("Attempting to reconnect to %s...", conn.URL.Load().(string)) err := conn.connect(ctx) if err != nil { conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.", - conn.cfg.URL, rcInt.Seconds()) + conn.URL.Load().(string), rcInt.Seconds()) time.AfterFunc(rcInt, func() { conn.reconnectCh <- struct{}{} }) @@ -479,7 +529,6 @@ func (conn *wsConn) keepAlive(ctx context.Context) { if conn.cfg.ReconnectSync != nil { conn.cfg.ReconnectSync() } - case <-ctx.Done(): return } diff --git a/client/core/core_test.go b/client/core/core_test.go index 8f6fa9c5b0..e0d097660a 100644 --- a/client/core/core_test.go +++ b/client/core/core_test.go @@ -346,6 +346,8 @@ func (conn *TWebsocket) Connect(context.Context) (*sync.WaitGroup, error) { return &sync.WaitGroup{}, conn.connectErr } +func (conn *TWebsocket) UpdateURL(rawURL string) {} + type TDB struct { updateWalletErr error acct *db.AccountInfo diff --git a/client/mm/libxc/binance.go b/client/mm/libxc/binance.go index 020d8a4818..79beab1edb 100644 --- a/client/mm/libxc/binance.go +++ b/client/mm/libxc/binance.go @@ -88,7 +88,7 @@ func newBinanceOrderBook( quoteConversionFactor: quoteConversionFactor, log: log, getSnapshot: getSnapshot, - connectedChan: make(chan bool), + connectedChan: make(chan bool, 4), } } @@ -164,7 +164,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error resyncChan := make(chan struct{}, 1) - desync := func() { + desync := func(resync bool) { // clear the sync cache, set the special ID, trigger a book refresh. syncMtx.Lock() defer syncMtx.Unlock() @@ -173,7 +173,9 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error if updateID != updateIDUnsynced { b.synced.Store(false) updateID = updateIDUnsynced - resyncChan <- struct{}{} + if resync { + resyncChan <- struct{}{} + } } } @@ -268,7 +270,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error case update := <-b.updateQueue: if !processUpdate(update) { b.log.Tracef("Bad %s update with ID %d", b.mktID, update.LastUpdateID) - desync() + desync(true) } case <-ctx.Done(): return @@ -288,13 +290,10 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error select { case <-retry: case <-resyncChan: - if retry != nil { // don't hammer - continue - } case connected := <-b.connectedChan: if !connected { - b.log.Debugf("Unsyncing %s orderbook due to disconnect.", b.mktID, retryFrequency) - desync() + b.log.Debugf("Unsyncing %s orderbook due to disconnect.", b.mktID) + desync(false) retry = nil continue } @@ -307,7 +306,7 @@ func (b *binanceOrderBook) Connect(ctx context.Context) (*sync.WaitGroup, error retry = nil } else { b.log.Infof("Failed to sync %s orderbook. Trying again in %s", b.mktID, retryFrequency) - desync() // Clears the syncCache + desync(false) // Clears the syncCache retry = time.After(retryFrequency) } } @@ -1647,12 +1646,13 @@ func (bnc *binance) subscribeToAdditionalMarketDataStream(ctx context.Context, b bnc.books[mktID] = book book.sync(ctx) + bnc.marketStream.UpdateURL(bnc.marketStreamsURL()) + return nil } +// bnc.booksMtx MUST be read locked when calling this function. func (bnc *binance) streams() []string { - bnc.booksMtx.RLock() - defer bnc.booksMtx.RUnlock() streamNames := make([]string, 0, len(bnc.books)) for mktID := range bnc.books { streamNames = append(streamNames, marketDataStreamID(mktID)) @@ -1660,13 +1660,22 @@ func (bnc *binance) streams() []string { return streamNames } +// bnc.booksMtx MUST be read locked when calling this function. +func (bnc *binance) marketStreamsURL() string { + return fmt.Sprintf("%s/stream?streams=%s", bnc.wsURL, strings.Join(bnc.streams(), "/")) +} + // checkSubs will query binance for current market subscriptions and compare // that to what subscriptions we should have. If there is a discrepancy a // warning is logged and the market subbed or unsubbed. func (bnc *binance) checkSubs(ctx context.Context) error { bnc.marketStreamMtx.Lock() defer bnc.marketStreamMtx.Unlock() + + bnc.booksMtx.RLock() streams := bnc.streams() + bnc.booksMtx.RUnlock() + if len(streams) == 0 { return nil } @@ -1746,61 +1755,9 @@ out: } // connectToMarketDataStream is called when the first market is subscribed to. -// It creates a connection to the market data stream and starts a goroutine -// to reconnect every 12 hours, as Binance will close the stream every 24 -// hours. Additional markets are subscribed to by calling +// Additional markets are subscribed to by calling // subscribeToAdditionalMarketDataStream. func (bnc *binance) connectToMarketDataStream(ctx context.Context, baseID, quoteID uint32) error { - reconnectC := make(chan struct{}) - - newConnection := func() (comms.WsConn, *dex.ConnectionMaster, error) { - addr := fmt.Sprintf("%s/stream?streams=%s", bnc.wsURL, strings.Join(bnc.streams(), "/")) - // Need to send key but not signature - connectEventFunc := func(cs comms.ConnectionStatus) { - if cs != comms.Disconnected && cs != comms.Connected { - return - } - // If disconnected, set all books to unsynced so bots - // will not place new orders. - connected := cs == comms.Connected - bnc.booksMtx.RLock() - defer bnc.booksMtx.RLock() - for _, b := range bnc.books { - b.connectedChan <- connected - } - } - conn, err := comms.NewWsConn(&comms.WsCfg{ - URL: addr, - // Binance Docs: The websocket server will send a ping frame every 3 - // minutes. If the websocket server does not receive a pong frame - // back from the connection within a 10 minute period, the connection - // will be disconnected. Unsolicited pong frames are allowed. - PingWait: time.Minute * 4, - EchoPingData: true, - ReconnectSync: func() { - bnc.log.Debugf("Binance reconnected") - select { - case reconnectC <- struct{}{}: - default: - } - }, - ConnectEventFunc: connectEventFunc, - Logger: bnc.log.SubLogger("BNCBOOK"), - RawHandler: bnc.handleMarketDataNote, - }) - if err != nil { - return nil, nil, err - } - - bnc.marketStream = conn - cm := dex.NewConnectionMaster(conn) - if err = cm.ConnectOnce(ctx); err != nil { - return nil, nil, fmt.Errorf("websocketHandler remote connect: %v", err) - } - - return conn, cm, nil - } - // Add the initial book to the books map baseCfg, quoteCfg, err := bncAssetCfgs(baseID, quoteID) if err != nil { @@ -1813,60 +1770,65 @@ func (bnc *binance) connectToMarketDataStream(ctx context.Context, baseID, quote } book := newBinanceOrderBook(baseCfg.conversionFactor, quoteCfg.conversionFactor, mktID, getSnapshot, bnc.log) bnc.books[mktID] = book + marketStreamsURL := bnc.marketStreamsURL() bnc.booksMtx.Unlock() - // Create initial connection to the market data stream - conn, cm, err := newConnection() + // Need to send key but not signature + connectEventFunc := func(cs comms.ConnectionStatus) { + if cs != comms.Disconnected && cs != comms.Connected { + return + } + + // If disconnected, set all books to unsynced so bots + // will not place new orders. + connected := cs == comms.Connected + + bnc.booksMtx.RLock() + defer bnc.booksMtx.RUnlock() + + for _, b := range bnc.books { + select { + case b.connectedChan <- connected: + default: // don't block + } + } + } + + reconnectInterval := 12 * time.Hour + conn, err := comms.NewWsConn(&comms.WsCfg{ + URL: marketStreamsURL, + // Binance Docs: The websocket server will send a ping frame every 3 + // minutes. If the websocket server does not receive a pong frame + // back from the connection within a 10 minute period, the connection + // will be disconnected. Unsolicited pong frames are allowed. + PingWait: time.Minute * 4, + EchoPingData: true, + ReconnectSync: func() { + bnc.log.Debugf("Binance reconnected") + }, + ConnectEventFunc: connectEventFunc, + Logger: bnc.log.SubLogger("BNCBOOK"), + RawHandler: bnc.handleMarketDataNote, + AutoReconnect: &reconnectInterval, + }) if err != nil { - return fmt.Errorf("error connecting to market data stream : %v", err) + return err + } + + cm := dex.NewConnectionMaster(conn) + if err = cm.ConnectOnce(ctx); err != nil { + return fmt.Errorf("websocketHandler remote connect: %v", err) } bnc.marketStream = conn book.sync(ctx) - // Start a goroutine to reconnect every 12 hours go func() { - reconnect := func() error { - bnc.marketStreamMtx.Lock() - defer bnc.marketStreamMtx.Unlock() - - oldCm := cm - conn, cm, err = newConnection() - if err != nil { - return err - } - - if oldCm != nil { - oldCm.Disconnect() - } - - bnc.marketStream = conn - return nil - } - checkSubsInterval := time.Minute checkSubs := time.After(checkSubsInterval) - reconnectTimer := time.After(time.Hour * 12) for { select { - case <-reconnectC: - if err := reconnect(); err != nil { - bnc.log.Errorf("Error reconnecting: %v", err) - reconnectTimer = time.After(time.Second * 30) - checkSubs = make(<-chan time.Time) - continue - } - checkSubs = time.After(checkSubsInterval) - case <-reconnectTimer: - if err := reconnect(); err != nil { - bnc.log.Errorf("Error refreshing connection: %v", err) - reconnectTimer = time.After(time.Second * 30) - checkSubs = make(<-chan time.Time) - continue - } - reconnectTimer = time.After(time.Hour * 12) - checkSubs = time.After(checkSubsInterval) case <-checkSubs: if err := bnc.checkSubs(ctx); err != nil { bnc.log.Errorf("Error checking subscriptions: %v", err) @@ -1934,6 +1896,7 @@ func (bnc *binance) UnsubscribeMarket(baseID, quoteID uint32) (err error) { unsubscribe = true delete(bnc.books, mktID) closer = book.cm + bnc.marketStream.UpdateURL(bnc.marketStreamsURL()) } book.mtx.Unlock()