Skip to content

Commit

Permalink
AutoReconnect and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
martonp committed Sep 28, 2024
1 parent 23f73db commit a15e541
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 168 deletions.
175 changes: 112 additions & 63 deletions client/comms/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ type WsConn interface {
RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error
Connect(ctx context.Context) (*sync.WaitGroup, error)
MessageSource() <-chan *msgjson.Message
UpdateURL(string)
}

// When the DEX sends a request to the client, a responseHandler is created
Expand All @@ -118,6 +119,10 @@ type WsCfg struct {
// latency.
PingWait time.Duration

// AutoReconnect, if non-nil, will reconnect to the server after each
// interval of the amount of time specified.
AutoReconnect *time.Duration

// The server's certificate.
Cert []byte

Expand Down Expand Up @@ -161,6 +166,7 @@ type wsConn struct {
cfg *WsCfg
tlsCfg *tls.Config
readCh chan *msgjson.Message
URL atomic.Value // string

wsMtx sync.Mutex
ws *websocket.Conn
Expand Down Expand Up @@ -203,14 +209,23 @@ func NewWsConn(cfg *WsCfg) (WsConn, error) {
ServerName: uri.Hostname(),
}

return &wsConn{
conn := &wsConn{
cfg: cfg,
log: cfg.Logger,
tlsCfg: tlsConfig,
readCh: make(chan *msgjson.Message, readBuffSize),
respHandlers: make(map[uint64]*responseHandler),
reconnectCh: make(chan struct{}, 1),
}, nil
}

conn.URL.Store(cfg.URL)

return conn, nil
}

// UpdateURL updates the URL that the connection uses when reconnecting.
func (conn *wsConn) UpdateURL(rawURL string) {
conn.URL.Store(rawURL)
}

// IsDown indicates if the connection is known to be down.
Expand Down Expand Up @@ -240,7 +255,7 @@ func (conn *wsConn) connect(ctx context.Context) error {
dialer.Proxy = http.ProxyFromEnvironment
}

ws, _, err := dialer.DialContext(ctx, conn.cfg.URL, conn.cfg.ConnectHeaders)
ws, _, err := dialer.DialContext(ctx, conn.URL.Load().(string), conn.cfg.ConnectHeaders)
if err != nil {
if isErrorInvalidCert(err) {
conn.setConnectionStatus(InvalidCert)
Expand Down Expand Up @@ -303,9 +318,9 @@ func (conn *wsConn) connect(ctx context.Context) error {
go func() {
defer conn.wg.Done()
if conn.cfg.RawHandler != nil {
conn.readRaw(ctx)
conn.readRaw(ctx, ws)
} else {
conn.read(ctx)
conn.read(ctx, ws)
}
}()

Expand All @@ -331,7 +346,7 @@ func (conn *wsConn) handleReadError(err error) {

var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
conn.log.Errorf("Read timeout on connection to %s.", conn.cfg.URL)
conn.log.Errorf("Read timeout on connection to %s.", conn.URL.Load().(string))
reconnect()
return
}
Expand Down Expand Up @@ -372,75 +387,110 @@ func (conn *wsConn) close() {
conn.ws.Close()
}

func (conn *wsConn) readRaw(ctx context.Context) {
for {
// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()
func (conn *wsConn) readRaw(ctx context.Context, ws *websocket.Conn) {
var reconnectTimer <-chan time.Time
if conn.cfg.AutoReconnect != nil {
reconnectTimer = time.After(*conn.cfg.AutoReconnect)
}

type readResult struct {
msgBytes []byte
err error
}

readMessage := func() chan *readResult {
ch := make(chan *readResult, 1)
go func() {
_, msgBytes, err := ws.ReadMessage()
ch <- &readResult{msgBytes, err}
}()
return ch
}

// Block until a message is received or an error occurs.
_, msgBytes, err := ws.ReadMessage()
// Drop the read error on context cancellation.
if ctx.Err() != nil {
for {
select {
case result := <-readMessage():
if ctx.Err() != nil {
return
}
if result.err != nil {
conn.handleReadError(result.err)
return
}
conn.cfg.RawHandler(result.msgBytes)
case <-reconnectTimer:
conn.reconnectCh <- struct{}{}
return
}
if err != nil {
conn.handleReadError(err)
case <-ctx.Done():
return
}
conn.cfg.RawHandler(msgBytes)
}
}

// read fetches and parses incoming messages for processing. This should be
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
for {
msg := new(msgjson.Message)
func (conn *wsConn) read(ctx context.Context, ws *websocket.Conn) {
var reconnectTimer <-chan time.Time
if conn.cfg.AutoReconnect != nil {
reconnectTimer = time.After(*conn.cfg.AutoReconnect)
}

// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()
type readResult struct {
msg *msgjson.Message
err error
}

// The read itself does not require locking since only this goroutine
// uses read functions that are not safe for concurrent use.
err := ws.ReadJSON(msg)
// Drop the read error on context cancellation.
if ctx.Err() != nil {
return
}
if err != nil {
var mErr *json.UnmarshalTypeError
if errors.As(err, &mErr) {
// JSON decode errors are not fatal, log and proceed.
conn.log.Errorf("json decode error: %v", mErr)
continue
}
conn.handleReadError(err)
return
}
readMessage := func() chan *readResult {
ch := make(chan *readResult, 1)
go func() {
msg := new(msgjson.Message)
err := ws.ReadJSON(msg)
ch <- &readResult{msg, err}
}()
return ch
}

// If the message is a response, find the handler.
if msg.Type == msgjson.Response {
handler := conn.respHandler(msg.ID)
if handler == nil {
b, _ := json.Marshal(msg)
conn.log.Errorf("No handler found for response: %v", string(b))
for {
select {
case result := <-readMessage():
if ctx.Err() != nil {
return
}
if result.err != nil {
var mErr *json.UnmarshalTypeError
if errors.As(result.err, &mErr) {
// JSON decode errors are not fatal, log and proceed.
conn.log.Errorf("json decode error: %v", mErr)
continue
}
conn.handleReadError(result.err)
return
}
// If the message is a response, find the handler.
if result.msg.Type == msgjson.Response {
handler := conn.respHandler(result.msg.ID)
if handler == nil {
b, _ := json.Marshal(result.msg)
conn.log.Errorf("No handler found for response: %v", string(b))
continue
}
// Run handlers in a goroutine so that other messages can be
// received. Include the handler goroutines in the WaitGroup to
// allow them to complete if the connection master desires.
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
handler.f(result.msg)
}()
continue
}
// Run handlers in a goroutine so that other messages can be
// received. Include the handler goroutines in the WaitGroup to
// allow them to complete if the connection master desires.
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
handler.f(msg)
}()
continue
conn.readCh <- result.msg
case <-reconnectTimer:
conn.reconnectCh <- struct{}{}
return
case <-ctx.Done():
return
}
conn.readCh <- msg
}
}

Expand All @@ -457,11 +507,11 @@ func (conn *wsConn) keepAlive(ctx context.Context) {
return
}

conn.log.Infof("Attempting to reconnect to %s...", conn.cfg.URL)
conn.log.Infof("Attempting to reconnect to %s...", conn.URL.Load().(string))
err := conn.connect(ctx)
if err != nil {
conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.",
conn.cfg.URL, rcInt.Seconds())
conn.URL.Load().(string), rcInt.Seconds())
time.AfterFunc(rcInt, func() {
conn.reconnectCh <- struct{}{}
})
Expand All @@ -479,7 +529,6 @@ func (conn *wsConn) keepAlive(ctx context.Context) {
if conn.cfg.ReconnectSync != nil {
conn.cfg.ReconnectSync()
}

case <-ctx.Done():
return
}
Expand Down
2 changes: 2 additions & 0 deletions client/core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ func (conn *TWebsocket) Connect(context.Context) (*sync.WaitGroup, error) {
return &sync.WaitGroup{}, conn.connectErr
}

func (conn *TWebsocket) UpdateURL(rawURL string) {}

type TDB struct {
updateWalletErr error
acct *db.AccountInfo
Expand Down
Loading

0 comments on commit a15e541

Please sign in to comment.