diff --git a/client_test.go b/client_test.go index 815de15..e5a0154 100644 --- a/client_test.go +++ b/client_test.go @@ -288,7 +288,41 @@ func TestOpen(t *testing.T) { if c, err := Open(rwc, defaultConfig()); err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } +} + +func TestOpenClose_ShouldNotPanic(t *testing.T) { + rwc, srv := newSession(t) + t.Cleanup(func() { + _ = rwc.Close() + }) + + go func() { + srv.connectionOpen() + srv.connectionClose() + }() + + c, err := Open(rwc, defaultConfig()) + if err != nil { + t.Fatalf("could not create connection: %v (%s)", c, err) + } + + if err := c.Close(); err != nil { + t.Fatalf("could not close connection: %s", err) + } + defer func() { + if r := recover(); r != nil { + t.Fatalf("creating a channel on a closed connection should not panic: %s", r) + } + }() + + ch, err := c.Channel() + if ch != nil { + t.Fatalf("creating a channel on a closed connection should not succeed: %v, (%s)", ch, err) + } + if err != ErrClosed { + t.Fatalf("error should be closed: %s", err) + } } func TestChannelOpen(t *testing.T) { diff --git a/connection.go b/connection.go index 3cad6e7..def2260 100644 --- a/connection.go +++ b/connection.go @@ -551,8 +551,8 @@ func (c *Connection) shutdown(err *Error) { c.conn.Close() - c.channels = map[uint16]*Channel{} - c.allocator = newAllocator(1, c.Config.ChannelMax) + c.channels = nil + c.allocator = nil c.noNotify = true }) } @@ -770,8 +770,10 @@ func (c *Connection) releaseChannel(id uint16) { c.m.Lock() defer c.m.Unlock() - delete(c.channels, id) - c.allocator.release(int(id)) + if !c.IsClosed() { + delete(c.channels, id) + c.allocator.release(int(id)) + } } // openChannel allocates and opens a channel, must be paired with closeChannel @@ -919,6 +921,7 @@ func (c *Connection) openTune(config Config, auth Authentication) error { // Edge case that may race with c.shutdown() // https://github.com/rabbitmq/amqp091-go/issues/170 c.m.Lock() + // When the server and client both use default 0, then the max channel is // only limited by uint16. c.Config.ChannelMax = pick(config.ChannelMax, int(tune.ChannelMax)) @@ -926,6 +929,9 @@ func (c *Connection) openTune(config Config, auth Authentication) error { c.Config.ChannelMax = defaultChannelMax } c.Config.ChannelMax = min(c.Config.ChannelMax, maxChannelMax) + + c.allocator = newAllocator(1, c.Config.ChannelMax) + c.m.Unlock() // Frame size includes headers and end byte (len(payload)+8), even if @@ -982,9 +988,6 @@ func (c *Connection) openComplete() error { _ = deadliner.SetDeadline(time.Time{}) } - c.m.Lock() - c.allocator = newAllocator(1, c.Config.ChannelMax) - c.m.Unlock() return nil }