diff --git a/mq.go b/mq.go index 161595e..6455a59 100644 --- a/mq.go +++ b/mq.go @@ -19,6 +19,10 @@ const ( // Describes states during reconnect. statusReadyForReconnect int32 = 0 statusReconnecting int32 = 1 + + ConnectionStateDisconnected ConnectionState = 1 + ConnectionStateConnected ConnectionState = 2 + ConnectionStateConnecting ConnectionState = 3 ) // Used for creating connection to the fake AMQP server for tests. @@ -45,8 +49,12 @@ type MQ interface { Error() <-chan error // Close stop all consumers and producers and close connection to broker. Close() + // Shows connection state + ConnectionState() ConnectionState } +type ConnectionState uint8 + type mq struct { channel wabbit.Channel config Config @@ -60,6 +68,7 @@ type mq struct { sync.Once currentNode int32 } + state *int32 } // New initializes AMQP connection to the message broker @@ -74,7 +83,9 @@ func New(config Config) (MQ, error) { internalErrorChannel: make(chan error), consumers: newConsumersRegistry(len(config.Consumers)), producers: newProducersRegistry(len(config.Producers)), + state: new(int32), } + atomic.StoreInt32(mq.state, int32(ConnectionStateDisconnected)) if err := mq.connect(); err != nil { return nil, err @@ -152,9 +163,15 @@ func (mq *mq) Close() { } } +func (mq *mq) ConnectionState() ConnectionState { + return ConnectionState(atomic.LoadInt32(mq.state)) +} + func (mq *mq) connect() error { + atomic.StoreInt32(mq.state, int32(ConnectionStateConnecting)) connection, err := mq.createConnection() if err != nil { + atomic.StoreInt32(mq.state, int32(ConnectionStateDisconnected)) return err } @@ -162,6 +179,7 @@ func (mq *mq) connect() error { if err != nil { _ = connection.Close() + atomic.StoreInt32(mq.state, int32(ConnectionStateDisconnected)) return err } @@ -170,6 +188,7 @@ func (mq *mq) connect() error { go mq.handleCloseEvent() + atomic.StoreInt32(mq.state, int32(ConnectionStateConnected)) return nil } @@ -195,6 +214,7 @@ func (mq *mq) handleCloseEvent() { if err != nil { mq.internalErrorChannel <- err } + atomic.StoreInt32(mq.state, int32(ConnectionStateDisconnected)) } func (mq *mq) errorHandler() { diff --git a/mq_test.go b/mq_test.go index 9899017..412e0d3 100644 --- a/mq_test.go +++ b/mq_test.go @@ -530,6 +530,81 @@ func Test_mq_createConnection(t *testing.T) { } } +func TestMq_ConnectionState(t *testing.T) { + cases := []struct { + name string + expected ConnectionState + }{ + {name: "status disconnected", expected: ConnectionStateDisconnected}, + {name: "status changed", expected: ConnectionStateConnecting}, + } + for _, tc := range cases { + cfg := newDefaultConfig() + cfg.TestMode = true + cfg.normalize() + + mq := &mq{ + config: cfg, + errorChannel: make(chan error), + internalErrorChannel: make(chan error), + consumers: newConsumersRegistry(len(cfg.Consumers)), + producers: newProducersRegistry(len(cfg.Producers)), + state: new(int32), + } + atomic.StoreInt32(mq.state, int32(tc.expected)) + + t.Run(tc.name, func(t *testing.T) { + defer mq.Close() + if mq.ConnectionState() != tc.expected { + t.Errorf("ConnectionState() current value %v, expected broker %v", mq.ConnectionState(), tc.expected) + } + }) + } + +} + +func TestMq_connect(t *testing.T) { + s := server.NewServer(dsnForTests) + _ = s.Start() + defer func() { _ = s.Stop() }() + cases := []struct { + name string + expected ConnectionState + isConnectError bool + isChannelError bool + }{ + {name: "success connect", expected: ConnectionStateConnected}, + {name: "failed to connect", expected: ConnectionStateDisconnected, isConnectError: true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := newDefaultConfig() + cfg.TestMode = true + cfg.normalize() + + mq := &mq{ + config: cfg, + errorChannel: make(chan error), + internalErrorChannel: make(chan error), + consumers: newConsumersRegistry(len(cfg.Consumers)), + producers: newProducersRegistry(len(cfg.Producers)), + state: new(int32), + } + defer mq.Close() + if tc.isConnectError { + _ = s.Stop() + } + err := mq.connect() + if err != nil && !tc.isConnectError { + t.Errorf("connect() no error expected, but got: %v", err) + } + if mq.ConnectionState() != tc.expected { + t.Errorf("connect() expected state %v, got: %v", tc.expected, mq.ConnectionState()) + } + }) + } +} + func assertNoMqError(t *testing.T, mq MQ) { select { case err := <-mq.Error():