Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate TCPKeepalive in server.Serve (#1320) #1324

Merged
merged 1 commit into from
Jun 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 21 additions & 50 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1554,34 +1554,6 @@ func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
return
}

// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe, ListenAndServeTLS and
// ListenAndServeTLSEmbed so dead TCP connections (e.g. closing laptop mid-download)
// eventually go away.
type tcpKeepaliveListener struct {
*net.TCPListener
keepalive bool
keepalivePeriod time.Duration
}

func (ln tcpKeepaliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(ln.keepalive); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
if ln.keepalivePeriod > 0 {
if err := tc.SetKeepAlivePeriod(ln.keepalivePeriod); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
}
return tc, nil
}

// ListenAndServe serves HTTP requests from the given TCP4 addr.
//
// Pass custom listener to Serve if you need listening on non-TCP4 media
Expand All @@ -1593,13 +1565,6 @@ func (s *Server) ListenAndServe(addr string) error {
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.Serve(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
})
}
return s.Serve(ln)
}

Expand Down Expand Up @@ -1638,13 +1603,6 @@ func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.ServeTLS(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
}, certFile, keyFile)
}
return s.ServeTLS(ln, certFile, keyFile)
}

Expand All @@ -1664,13 +1622,6 @@ func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) e
if err != nil {
return err
}
if tcpln, ok := ln.(*net.TCPListener); ok {
return s.ServeTLSEmbed(tcpKeepaliveListener{
TCPListener: tcpln,
keepalive: s.TCPKeepalive,
keepalivePeriod: s.TCPKeepalivePeriod,
}, certData, keyData)
}
return s.ServeTLSEmbed(ln, certData, keyData)
}

Expand Down Expand Up @@ -1910,7 +1861,27 @@ func (s *Server) Shutdown() error {

func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
for {
c, err := ln.Accept()
var c net.Conn
var err error
if tl, ok := ln.(*net.TCPListener); ok && s.TCPKeepalive {
tc, err := tl.AcceptTCP()
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(s.TCPKeepalive); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
if s.TCPKeepalivePeriod > 0 {
if err := tc.SetKeepAlivePeriod(s.TCPKeepalivePeriod); err != nil {
tc.Close() //nolint:errcheck
return nil, err
}
}
c = tc
} else {
c, err = ln.Accept()
}
if err != nil {
if c != nil {
panic("BUG: net.Listener returned non-nil conn and non-nil error")
Expand Down