diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index 00ce5f785..32f075782 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -78,6 +78,8 @@ type Connection interface { // SetIdleTimeout sets the amount of time the connection may remain idle before // it is automatically closed. SetIdleTimeout(timeout time.Duration) + // RemoveStreams can be used to remove a set of streams from the Connection. + RemoveStreams(streams ...Stream) } // Stream represents a bidirectional communications channel that is part of an diff --git a/pkg/util/httpstream/spdy/connection.go b/pkg/util/httpstream/spdy/connection.go index 7a6881250..6cd22b9b6 100644 --- a/pkg/util/httpstream/spdy/connection.go +++ b/pkg/util/httpstream/spdy/connection.go @@ -31,7 +31,7 @@ import ( // streams. type connection struct { conn *spdystream.Connection - streams []httpstream.Stream + streams map[uint32]httpstream.Stream streamLock sync.Mutex newStreamHandler httpstream.NewStreamHandler } @@ -64,7 +64,11 @@ func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHan // will be invoked when the server receives a newly created stream from the // client. func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { - c := &connection{conn: conn, newStreamHandler: newStreamHandler} + c := &connection{ + conn: conn, + newStreamHandler: newStreamHandler, + streams: make(map[uint32]httpstream.Stream), + } go conn.Serve(c.newSpdyStream) return c } @@ -81,7 +85,7 @@ func (c *connection) Close() error { // calling Reset instead of Close ensures that all streams are fully torn down s.Reset() } - c.streams = make([]httpstream.Stream, 0) + c.streams = make(map[uint32]httpstream.Stream, 0) c.streamLock.Unlock() // now that all streams are fully torn down, it's safe to call close on the underlying connection, @@ -90,6 +94,15 @@ func (c *connection) Close() error { return c.conn.Close() } +// RemoveStreams can be used to removes a set of streams from the Connection. +func (c *connection) RemoveStreams(streams ...httpstream.Stream) { + c.streamLock.Lock() + for _, stream := range streams { + delete(c.streams, stream.Identifier()) + } + c.streamLock.Unlock() +} + // CreateStream creates a new stream with the specified headers and registers // it with the connection. func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) { @@ -109,7 +122,7 @@ func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error // it owns. func (c *connection) registerStream(s httpstream.Stream) { c.streamLock.Lock() - c.streams = append(c.streams, s) + c.streams[s.Identifier()] = s c.streamLock.Unlock() } diff --git a/pkg/util/httpstream/spdy/connection_test.go b/pkg/util/httpstream/spdy/connection_test.go index e00b29c46..cfeef2c90 100644 --- a/pkg/util/httpstream/spdy/connection_test.go +++ b/pkg/util/httpstream/spdy/connection_test.go @@ -178,3 +178,41 @@ func TestConnectionCloseIsImmediateThroughAProxy(t *testing.T) { } } } + +type fakeStream struct{ id uint32 } + +func (*fakeStream) Read(p []byte) (int, error) { return 0, nil } +func (*fakeStream) Write(p []byte) (int, error) { return 0, nil } +func (*fakeStream) Close() error { return nil } +func (*fakeStream) Reset() error { return nil } +func (*fakeStream) Headers() http.Header { return nil } +func (f *fakeStream) Identifier() uint32 { return f.id } + +func TestConnectionRemoveStreams(t *testing.T) { + c := &connection{streams: make(map[uint32]httpstream.Stream)} + stream0 := &fakeStream{id: 0} + stream1 := &fakeStream{id: 1} + stream2 := &fakeStream{id: 2} + + c.registerStream(stream0) + c.registerStream(stream1) + + if len(c.streams) != 2 { + t.Fatalf("should have two streams, has %d", len(c.streams)) + } + + // not exists + c.RemoveStreams(stream2) + + if len(c.streams) != 2 { + t.Fatalf("should have two streams, has %d", len(c.streams)) + } + + // remove all existing + c.RemoveStreams(stream0, stream1) + + if len(c.streams) != 0 { + t.Fatalf("should not have any streams, has %d", len(c.streams)) + } + +}