Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 53 additions & 75 deletions modules/caddyhttp/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (

"go.uber.org/zap"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyevents"
Expand Down Expand Up @@ -236,15 +235,6 @@ func (app *App) Provision(ctx caddy.Context) error {
for _, srvProtocol := range srv.Protocols {
srvProtocolsUnique[srvProtocol] = struct{}{}
}
_, h1ok := srvProtocolsUnique["h1"]
_, h2ok := srvProtocolsUnique["h2"]
_, h2cok := srvProtocolsUnique["h2c"]

// the Go standard library does not let us serve only HTTP/2 using
// http.Server; we would probably need to write our own server
if !h1ok && (h2ok || h2cok) {
return fmt.Errorf("server %s: cannot enable HTTP/2 or H2C without enabling HTTP/1.1; add h1 to protocols or remove h2/h2c", srvName)
}

if srv.ListenProtocols != nil {
if len(srv.ListenProtocols) != len(srv.Listen) {
Expand Down Expand Up @@ -278,19 +268,6 @@ func (app *App) Provision(ctx caddy.Context) error {
}
}

lnProtocolsIncludeUnique := map[string]struct{}{}
for _, lnProtocol := range lnProtocolsInclude {
lnProtocolsIncludeUnique[lnProtocol] = struct{}{}
}
_, h1ok := lnProtocolsIncludeUnique["h1"]
_, h2ok := lnProtocolsIncludeUnique["h2"]
_, h2cok := lnProtocolsIncludeUnique["h2c"]

// check if any listener protocols contain h2 or h2c without h1
if !h1ok && (h2ok || h2cok) {
return fmt.Errorf("server %s, listener %d: cannot enable HTTP/2 or H2C without enabling HTTP/1.1; add h1 to protocols or remove h2/h2c", srvName, i)
}

srv.ListenProtocols[i] = lnProtocolsInclude
}
}
Expand Down Expand Up @@ -448,6 +425,25 @@ func (app *App) Validate() error {
return nil
}

func removeTLSALPN(srv *Server, target string) {
for _, cp := range srv.TLSConnPolicies {
// the TLSConfig was already provisioned, so... manually remove it
for i, np := range cp.TLSConfig.NextProtos {
if np == target {
cp.TLSConfig.NextProtos = append(cp.TLSConfig.NextProtos[:i], cp.TLSConfig.NextProtos[i+1:]...)
break
}
}
// remove it from the parent connection policy too, just to keep things tidy
for i, alpn := range cp.ALPN {
if alpn == target {
cp.ALPN = append(cp.ALPN[:i], cp.ALPN[i+1:]...)
break
}
}
}
}

// Start runs the app. It finishes automatic HTTPS if enabled,
// including management of certificates.
func (app *App) Start() error {
Expand All @@ -466,32 +462,37 @@ func (app *App) Start() error {
MaxHeaderBytes: srv.MaxHeaderBytes,
Handler: srv,
ErrorLog: serverLogger,
Protocols: new(http.Protocols),
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, ConnCtxKey, c)
},
}
h2server := new(http2.Server)

// disable HTTP/2, which we enabled by default during provisioning
if !srv.protocol("h2") {
srv.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
for _, cp := range srv.TLSConnPolicies {
// the TLSConfig was already provisioned, so... manually remove it
for i, np := range cp.TLSConfig.NextProtos {
if np == "h2" {
cp.TLSConfig.NextProtos = append(cp.TLSConfig.NextProtos[:i], cp.TLSConfig.NextProtos[i+1:]...)
break
}
}
// remove it from the parent connection policy too, just to keep things tidy
for i, alpn := range cp.ALPN {
if alpn == "h2" {
cp.ALPN = append(cp.ALPN[:i], cp.ALPN[i+1:]...)
break
}
}
}
} else {
removeTLSALPN(srv, "h2")
}
if !srv.protocol("h1") {
removeTLSALPN(srv, "http/1.1")
}

// configure the http versions the server will serve
if srv.protocol("h1") {
srv.server.Protocols.SetHTTP1(true)
}

if srv.protocol("h2") || srv.protocol("h2c") {
// skip setting h2 because if NextProtos is present, it's list of alpn versions will take precedence.
// it will always be present because http2.ConfigureServer will populate that field
// enabling h2c because some listener wrapper will wrap the connection that is no longer *tls.Conn
// However, we need to handle the case that if the connection is h2c but h2c is not enabled. We identify
// this type of connection by checking if it's behind a TLS listener wrapper or if it implements tls.ConnectionState.
srv.server.Protocols.SetUnencryptedHTTP2(true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this only be set if h2c, but not h2?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is deliberate. if the handshake succeeds with h2, it will be handled accordingly. h2c is enabled to handle listener wrappers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if the user only configured h2 and not h2c, it'll still have SetUnencryptedHTTP2(true), right? That doesn't seem correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or wait, you mean if there's LWs then it could appear as though the conn is no longer encrypted cause it's already been decrypted? Is that what you mean? Tricky.

// when h2c is enabled but h2 disabled, we already removed h2 from NextProtos
// the handshake will never succeed with h2
// http2.ConfigureServer will enable the server to handle both h2 and h2c
h2server := new(http2.Server)
//nolint:errcheck
http2.ConfigureServer(srv.server, h2server)
}
Expand All @@ -501,11 +502,6 @@ func (app *App) Start() error {
tlsCfg := srv.TLSConnPolicies.TLSConfig(app.ctx)
srv.configureServer(srv.server)

// enable H2C if configured
if srv.protocol("h2c") {
srv.server.Handler = h2c.NewHandler(srv, h2server)
}

for lnIndex, lnAddr := range srv.Listen {
listenAddr, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil {
Expand Down Expand Up @@ -570,15 +566,13 @@ func (app *App) Start() error {
ln = srv.listenerWrappers[i].WrapListener(ln)
}

// handle http2 if use tls listener wrapper
if h2ok {
http2lnWrapper := &http2Listener{
Listener: ln,
server: srv.server,
h2server: h2server,
}
srv.h2listeners = append(srv.h2listeners, http2lnWrapper)
ln = http2lnWrapper
// check if the connection is h2c
ln = &http2Listener{
useTLS: useTLS,
useH1: h1ok,
useH2: h2ok || h2cok,
Listener: ln,
logger: app.logger,
}

// if binding to port 0, the OS chooses a port for us;
Expand All @@ -596,11 +590,8 @@ func (app *App) Start() error {

srv.listeners = append(srv.listeners, ln)

// enable HTTP/1 if configured
if h1ok {
//nolint:errcheck
go srv.server.Serve(ln)
}
//nolint:errcheck
go srv.server.Serve(ln)
}

if h2ok && !useTLS {
Expand Down Expand Up @@ -756,25 +747,12 @@ func (app *App) Stop() error {
}
}
}
stopH2Listener := func(server *Server) {
defer finishedShutdown.Done()
startedShutdown.Done()

for i, s := range server.h2listeners {
if err := s.Shutdown(ctx); err != nil {
app.logger.Error("http2 listener shutdown",
zap.Error(err),
zap.Int("index", i))
}
}
}

for _, server := range app.Servers {
startedShutdown.Add(3)
finishedShutdown.Add(3)
startedShutdown.Add(2)
finishedShutdown.Add(2)
go stopServer(server)
go stopH3Server(server)
go stopH2Listener(server)
}

// block until all the goroutines have been run by the scheduler;
Expand Down
154 changes: 81 additions & 73 deletions modules/caddyhttp/http2listener.go
Original file line number Diff line number Diff line change
@@ -1,102 +1,110 @@
package caddyhttp

import (
"context"
"crypto/tls"
weakrand "math/rand"
"io"
"net"
"net/http"
"sync/atomic"
"time"

"go.uber.org/zap"
"golang.org/x/net/http2"
)

type connectionStater interface {
ConnectionState() tls.ConnectionState
}

// http2Listener wraps the listener to solve the following problems:
// 1. server h2 natively without using h2c hack when listener handles tls connection but
// don't return *tls.Conn
// 2. graceful shutdown. the shutdown logic is copied from stdlib http.Server, it's an extra maintenance burden but
// whatever, the shutdown logic maybe extracted to be used with h2c graceful shutdown. http2.Server supports graceful shutdown
// sending GO_AWAY frame to connected clients, but doesn't track connection status. It requires explicit call of http2.ConfigureServer
// 1. prevent genuine h2c connections from succeeding if h2c is not enabled
// and the connection doesn't implment connectionStater or the resulting NegotiatedProtocol
// isn't http2.
// This does allow a connection to pass as tls enabled even if it's not, listener wrappers
// can do this.
// 2. After wrapping the connection doesn't implement connectionStater, emit a warning so that listener
// wrapper authors will hopefully implement it.
// 3. check if the connection matches a specific http version. h2/h2c has a distinct preface.
type http2Listener struct {
cnt uint64
useTLS bool
useH1 bool
useH2 bool
net.Listener
server *http.Server
h2server *http2.Server
}

type connectionStateConn interface {
net.Conn
ConnectionState() tls.ConnectionState
logger *zap.Logger
}

func (h *http2Listener) Accept() (net.Conn, error) {
for {
conn, err := h.Listener.Accept()
if err != nil {
return nil, err
}
conn, err := h.Listener.Accept()
if err != nil {
return nil, err
}

if csc, ok := conn.(connectionStateConn); ok {
// *tls.Conn will return empty string because it's only populated after handshake is complete
if csc.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
go h.serveHttp2(csc)
continue
}
}
_, isConnectionStater := conn.(connectionStater)
// emit a warning
if h.useTLS && !isConnectionStater {
h.logger.Warn("tls is enabled, but listener wrapper returns a connection that doesn't implement connectionStater")
} else if !h.useTLS && isConnectionStater {
h.logger.Warn("tls is disabled, but listener wrapper returns a connection that implements connectionStater")
}

// if both h1 and h2 are enabled, we don't need to check the preface
if h.useH1 && h.useH2 {
return conn, nil
}

// impossible both are false, either useH1 or useH2 must be true,
// or else the listener wouldn't be created
h2Conn := &http2Conn{
h2Expected: h.useH2,
Conn: conn,
}
if isConnectionStater {
return http2StateConn{h2Conn}, nil
}
return h2Conn, nil
}

func (h *http2Listener) serveHttp2(csc connectionStateConn) {
atomic.AddUint64(&h.cnt, 1)
h.runHook(csc, http.StateNew)
defer func() {
csc.Close()
atomic.AddUint64(&h.cnt, ^uint64(0))
h.runHook(csc, http.StateClosed)
}()
h.h2server.ServeConn(csc, &http2.ServeConnOpts{
Context: h.server.ConnContext(context.Background(), csc),
BaseConfig: h.server,
Handler: h.server.Handler,
})
type http2StateConn struct {
*http2Conn
}

const shutdownPollIntervalMax = 500 * time.Millisecond
func (conn http2StateConn) ConnectionState() tls.ConnectionState {
return conn.Conn.(connectionStater).ConnectionState()
}

func (h *http2Listener) Shutdown(ctx context.Context) error {
pollIntervalBase := time.Millisecond
nextPollInterval := func() time.Duration {
// Add 10% jitter.
//nolint:gosec
interval := pollIntervalBase + time.Duration(weakrand.Intn(int(pollIntervalBase/10)))
// Double and clamp for next time.
pollIntervalBase *= 2
if pollIntervalBase > shutdownPollIntervalMax {
pollIntervalBase = shutdownPollIntervalMax
}
return interval
}
type http2Conn struct {
// current index where the preface should match,
// no matching is done if idx is >= len(http2.ClientPreface)
idx int
// whether the connection is expected to be h2/h2c
h2Expected bool
// log if one such connection is detected
logger *zap.Logger
net.Conn
}

timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
if atomic.LoadUint64(&h.cnt) == 0 {
return nil
func (c *http2Conn) Read(p []byte) (int, error) {
if c.idx >= len(http2.ClientPreface) {
return c.Conn.Read(p)
}
n, err := c.Conn.Read(p)
for i := range n {
// first mismatch
if p[i] != http2.ClientPreface[c.idx] {
// close the connection if h2 is expected
if c.h2Expected {
c.logger.Debug("h1 connection detected, but h1 is not enabled")
_ = c.Conn.Close()
return 0, io.EOF
}
// no need to continue matching anymore
c.idx = len(http2.ClientPreface)
return n, err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
timer.Reset(nextPollInterval())
c.idx++
// matching complete
if c.idx == len(http2.ClientPreface) && !c.h2Expected {
c.logger.Debug("h2/h2c connection detected, but h2/h2c is not enabled")
_ = c.Conn.Close()
return 0, io.EOF
}
}
}

func (h *http2Listener) runHook(conn net.Conn, state http.ConnState) {
if h.server.ConnState != nil {
h.server.ConnState(conn, state)
}
return n, err
}
Loading
Loading