Skip to content

Commit

Permalink
Merge pull request #133 from maelick/master
Browse files Browse the repository at this point in the history
Unsubscribe timeout to prevent deadlock (#1)
  • Loading branch information
worg authored Mar 2, 2024
2 parents b7da9a7 + 62a2b00 commit 5179ed7
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 55 deletions.
54 changes: 30 additions & 24 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,32 @@ const DefaultRcvReceiptTimeout = 30 * time.Second
// Default receipt timeout in Conn.Disconnect function
const DefaultDisconnectReceiptTimeout = 30 * time.Second

// Default receipt timeout in Subscription.Unsubscribe function
const DefaultUnsubscribeReceiptTimeout = 30 * time.Second

// Reply-To header used for temporary queues/RPC with rabbit.
const ReplyToHeader = "reply-to"

// A Conn is a connection to a STOMP server. Create a Conn using either
// the Dial or Connect function.
type Conn struct {
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
disconnectReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
disconnectReceiptTimeout time.Duration
unsubscribeReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
}

type writeRequest struct {
Expand Down Expand Up @@ -204,6 +208,7 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error)
c.msgSendTimeout = options.MsgSendTimeout
c.rcvReceiptTimeout = options.RcvReceiptTimeout
c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout
c.unsubscribeReceiptTimeout = options.UnsubscribeReceiptTimeout

if options.ResponseHeadersCallback != nil {
options.ResponseHeadersCallback(response.Header)
Expand Down Expand Up @@ -678,14 +683,15 @@ func (c *Conn) Subscribe(destination string, ack AckMode, opts ...func(*frame.Fr

closeMutex := &sync.Mutex{}
sub := &Subscription{
id: id,
replyToSet: replyToSet,
destination: destination,
conn: c,
ackMode: ack,
C: make(chan *Message, 16),
closeMutex: closeMutex,
closeCond: sync.NewCond(closeMutex),
id: id,
replyToSet: replyToSet,
destination: destination,
conn: c,
ackMode: ack,
C: make(chan *Message, 16),
closeMutex: closeMutex,
closeCond: sync.NewCond(closeMutex),
unsubscribeReceiptTimeout: c.unsubscribeReceiptTimeout,
}
go sub.readLoop(ch)

Expand Down
14 changes: 14 additions & 0 deletions conn_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type connOptions struct {
MsgSendTimeout time.Duration
RcvReceiptTimeout time.Duration
DisconnectReceiptTimeout time.Duration
UnsubscribeReceiptTimeout time.Duration
HeartBeatGracePeriodMultiplier float64
Login, Passcode string
AcceptVersions []string
Expand All @@ -40,6 +41,7 @@ func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error)
MsgSendTimeout: DefaultMsgSendTimeout,
RcvReceiptTimeout: DefaultRcvReceiptTimeout,
DisconnectReceiptTimeout: DefaultDisconnectReceiptTimeout,
UnsubscribeReceiptTimeout: DefaultUnsubscribeReceiptTimeout,
Logger: log.StdLogger{},
}

Expand Down Expand Up @@ -156,6 +158,11 @@ var ConnOpt struct {
// avoid deadlocks. If this is not specified, the default is 30 seconds.
DisconnectReceiptTimeout func(disconnectReceiptTimeout time.Duration) func(*Conn) error

// UnsubscribeReceiptTimeout is a connect option that allows the client to specify
// how long to wait for a receipt in the Conn.Unsubscribe function. This helps
// avoid deadlocks. If this is not specified, the default is 30 seconds.
UnsubscribeReceiptTimeout func(unsubscribeReceiptTimeout time.Duration) func(*Conn) error

// HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout
// the broker will enforce for each client’s connection. The multiplier is applied to
// the read-timeout interval the client specifies in its CONNECT frame
Expand Down Expand Up @@ -262,6 +269,13 @@ func init() {
}
}

ConnOpt.UnsubscribeReceiptTimeout = func(unsubscribeReceiptTimeout time.Duration) func(*Conn) error {
return func(c *Conn) error {
c.options.UnsubscribeReceiptTimeout = unsubscribeReceiptTimeout
return nil
}
}

ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error {
return func(c *Conn) error {
c.options.HeartBeatGracePeriodMultiplier = multiplier
Expand Down
29 changes: 15 additions & 14 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ import (

// Error values
var (
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout")
ErrNilOption = newErrorMessage("nil option")
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout")
ErrUnsubscribeReceiptTimeout = newErrorMessage("unsubscribe receipt timeout")
ErrNilOption = newErrorMessage("nil option")
)

// StompError implements the Error interface, and provides
Expand Down
69 changes: 52 additions & 17 deletions subscription.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package stomp

import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/go-stomp/stomp/v3/frame"
)
Expand All @@ -19,15 +21,16 @@ const (
//
// Once a client has subscribed, it can receive messages from the C channel.
type Subscription struct {
C chan *Message
id string
replyToSet bool
destination string
conn *Conn
ackMode AckMode
state int32
closeMutex *sync.Mutex
closeCond *sync.Cond
C chan *Message
id string
replyToSet bool
destination string
conn *Conn
ackMode AckMode
state int32
closeMutex *sync.Mutex
closeCond *sync.Cond
unsubscribeReceiptTimeout time.Duration
}

// BUG(jpj): If the client does not read messages from the Subscription.C
Expand Down Expand Up @@ -80,7 +83,12 @@ func (s *Subscription) Unsubscribe(opts ...func(*frame.Frame) error) error {
f.Header.Set(ReplyToHeader, s.id)
}

s.conn.sendFrame(f)
err := s.conn.sendFrame(f)
if errors.Is(err, ErrClosedUnexpectedly) {
msg := s.subscriptionErrorMessage("connection closed unexpectedly")
s.closeChannel(msg)
return err
}

// UNSUBSCRIBE is a bit weird in that it is tagged with a "receipt" header
// on the I/O goroutine, so the above call to sendFrame() will not wait
Expand All @@ -91,10 +99,33 @@ func (s *Subscription) Unsubscribe(opts ...func(*frame.Frame) error) error {
// wait for the terminal state transition instead.
s.closeMutex.Lock()
for atomic.LoadInt32(&s.state) != subStateClosed {
s.closeCond.Wait()
err = waitWithTimeout(s.closeCond, s.unsubscribeReceiptTimeout)
if err != nil && errors.Is(err, &ErrUnsubscribeReceiptTimeout) {
msg := s.subscriptionErrorMessage("channel unsubscribe receipt timeout")
s.C <- msg
return err
}
}
s.closeMutex.Unlock()
return nil
return err
}

func waitWithTimeout(cond *sync.Cond, timeout time.Duration) error {
if timeout == 0 {
cond.Wait()
return nil
}
waitChan := make(chan struct{})
go func() {
cond.Wait()
close(waitChan)
}()
select {
case <-waitChan:
return nil
case <-time.After(timeout):
return &ErrUnsubscribeReceiptTimeout
}
}

// Read a message from the subscription. This is a convenience
Expand Down Expand Up @@ -123,17 +154,21 @@ func (s *Subscription) closeChannel(msg *Message) {
s.closeCond.Broadcast()
}

func (s *Subscription) subscriptionErrorMessage(message string) *Message {
return &Message{
Err: &Error{
Message: fmt.Sprintf("Subscription %s: %s: %s", s.id, s.destination, message),
},
}
}

func (s *Subscription) readLoop(ch chan *frame.Frame) {
for {
f, ok := <-ch
if !ok {
state := atomic.LoadInt32(&s.state)
if state == subStateActive || state == subStateClosing {
msg := &Message{
Err: &Error{
Message: fmt.Sprintf("Subscription %s: %s: channel read failed", s.id, s.destination),
},
}
msg := s.subscriptionErrorMessage("channel read failed")
s.closeChannel(msg)
}
return
Expand Down
Loading

0 comments on commit 5179ed7

Please sign in to comment.