diff --git a/internal/quic/config.go b/internal/quic/config.go new file mode 100644 index 000000000..7d1b7433a --- /dev/null +++ b/internal/quic/config.go @@ -0,0 +1,20 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "crypto/tls" +) + +// A Config structure configures a QUIC endpoint. +// A Config must not be modified after it has been passed to a QUIC function. +// A Config may be reused; the quic package will also not modify it. +type Config struct { + // TLSConfig is the endpoint's TLS configuration. + // It must be non-nil and include at least one certificate or else set GetCertificate. + TLSConfig *tls.Config +} diff --git a/internal/quic/conn.go b/internal/quic/conn.go index e6375e86a..8130c549b 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -7,6 +7,7 @@ package quic import ( + "crypto/tls" "errors" "fmt" "net/netip" @@ -19,6 +20,7 @@ import ( type Conn struct { side connSide listener connListener + config *Config testHooks connTestHooks peerAddr netip.AddrPort @@ -29,14 +31,27 @@ type Conn struct { w packetWriter acks [numberSpaceCount]ackState // indexed by number space connIDState connIDState - tlsState tlsState loss lossState + // errForPeer is set when the connection is being closed. + errForPeer error + connCloseSent [numberSpaceCount]bool + // idleTimeout is the time at which the connection will be closed due to inactivity. // https://www.rfc-editor.org/rfc/rfc9000#section-10.1 maxIdleTimeout time.Duration idleTimeout time.Time + // Packet protection keys, CRYPTO streams, and TLS state. + rkeys [numberSpaceCount]keys + wkeys [numberSpaceCount]keys + crypto [numberSpaceCount]cryptoStream + tls *tls.QUICConn + + // handshakeConfirmed is set when the handshake is confirmed. + // For server connections, it tracks sending HANDSHAKE_DONE. + handshakeConfirmed sentVal + peerAckDelayExponent int8 // -1 when unknown // Tests only: Send a PING in a specific number space. @@ -53,12 +68,14 @@ type connListener interface { // connTestHooks override conn behavior in tests. type connTestHooks interface { nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any) + handleTLSEvent(tls.QUICEvent) } -func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, l connListener, hooks connTestHooks) (*Conn, error) { +func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) { c := &Conn{ side: side, listener: l, + config: config, peerAddr: peerAddr, msgc: make(chan any, 1), donec: make(chan struct{}), @@ -88,12 +105,58 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. const maxDatagramSize = 1200 c.loss.init(c.side, maxDatagramSize, now) - c.tlsState.init(c.side, initialConnID) + c.startTLS(now, initialConnID, transportParameters{ + initialSrcConnID: c.connIDState.srcConnID(), + ackDelayExponent: ackDelayExponent, + maxUDPPayloadSize: maxUDPPayloadSize, + maxAckDelay: maxAckDelay, + }) go c.loop(now) return c, nil } +// confirmHandshake is called when the handshake is confirmed. +// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 +func (c *Conn) confirmHandshake(now time.Time) { + // If handshakeConfirmed is unset, the handshake is not confirmed. + // If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE. + // If it is sent, we have sent a HANDSHAKE_DONE. + // If it is received, the handshake is confirmed and we do not need to send anything. + if c.handshakeConfirmed.isSet() { + return // already confirmed + } + if c.side == serverSide { + // When the server confirms the handshake, it sends a HANDSHAKE_DONE. + c.handshakeConfirmed.setUnsent() + } else { + // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed + // to the received state, indicating that the handshake is confirmed and we + // don't need to send anything. + c.handshakeConfirmed.setReceived() + } + c.loss.confirmHandshake() + // "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed" + // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1 + c.discardKeys(now, handshakeSpace) +} + +// discardKeys discards unused packet protection keys. +// https://www.rfc-editor.org/rfc/rfc9001#section-4.9 +func (c *Conn) discardKeys(now time.Time, space numberSpace) { + c.rkeys[space].discard() + c.wkeys[space].discard() + c.loss.discardKeys(now, space) +} + +// receiveTransportParameters applies transport parameters sent by the peer. +func (c *Conn) receiveTransportParameters(p transportParameters) { + c.peerAckDelayExponent = p.ackDelayExponent + c.loss.setMaxAckDelay(p.maxAckDelay) + + // TODO: Many more transport parameters to come. +} + type timerEvent struct{} // loop is the connection main loop. @@ -104,6 +167,7 @@ type timerEvent struct{} // Other goroutines may examine or modify conn state by sending the loop funcs to execute. func (c *Conn) loop(now time.Time) { defer close(c.donec) + defer c.tls.Close() // The connection timer sends a message to the connection loop on expiry. // We need to give it an expiry when creating it, so set the initial timeout to @@ -201,8 +265,9 @@ func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error { // abort terminates a connection with an error. func (c *Conn) abort(now time.Time, err error) { - // TODO: Send CONNECTION_CLOSE frames. - c.exit() + if c.errForPeer == nil { + c.errForPeer = err + } } // exit fully terminates a connection immediately. diff --git a/internal/quic/conn_loss.go b/internal/quic/conn_loss.go index 11ed42dbb..6cb459c33 100644 --- a/internal/quic/conn_loss.go +++ b/internal/quic/conn_loss.go @@ -29,7 +29,7 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF for !sent.done() { switch f := sent.next(); f { default: - panic(fmt.Sprintf("BUG: unhandled lost frame type %x", f)) + panic(fmt.Sprintf("BUG: unhandled acked/lost frame type %x", f)) case frameTypeAck: // Unlike most information, loss of an ACK frame does not trigger // retransmission. ACKs are sent in response to ack-eliciting packets, @@ -41,6 +41,11 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF if fate == packetAcked { c.acks[space].handleAck(largest) } + case frameTypeCrypto: + start, end := sent.nextRange() + c.crypto[space].ackOrLoss(start, end, fate) + case frameTypeHandshakeDone: + c.handshakeConfirmed.ackOrLoss(sent.num, fate) } } } diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go new file mode 100644 index 000000000..be4f5fb2c --- /dev/null +++ b/internal/quic/conn_loss_test.go @@ -0,0 +1,143 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "crypto/tls" + "testing" +) + +// Frames may be retransmitted either when the packet containing the frame is lost, or on PTO. +// lostFrameTest runs a test in both configurations. +func lostFrameTest(t *testing.T, f func(t *testing.T, pto bool)) { + t.Run("lost", func(t *testing.T) { + f(t, false) + }) + t.Run("pto", func(t *testing.T) { + f(t, true) + }) +} + +// triggerLossOrPTO causes the conn to declare the last sent packet lost, +// or advances to the PTO timer. +func (tc *testConn) triggerLossOrPTO(ptype packetType, pto bool) { + tc.t.Helper() + if pto { + if !tc.conn.loss.ptoTimerArmed { + tc.t.Fatalf("PTO timer not armed, expected it to be") + } + tc.advanceTo(tc.conn.loss.timer) + return + } + defer func(ignoreFrames map[byte]bool) { + tc.ignoreFrames = ignoreFrames + }(tc.ignoreFrames) + tc.ignoreFrames = map[byte]bool{ + frameTypeAck: true, + frameTypePadding: true, + } + // Send three packets containing PINGs, and then respond with an ACK for the + // last one. This puts the last packet before the PINGs outside the packet + // reordering threshold, and it will be declared lost. + const lossThreshold = 3 + var num packetNumber + for i := 0; i < lossThreshold; i++ { + tc.conn.ping(spaceForPacketType(ptype)) + d := tc.readDatagram() + if d == nil { + tc.t.Fatalf("conn is idle; want PING frame") + } + if d.packets[0].ptype != ptype { + tc.t.Fatalf("conn sent %v packet; want %v", d.packets[0].ptype, ptype) + } + num = d.packets[0].num + } + tc.writeFrames(ptype, debugFrameAck{ + ranges: []i64range[packetNumber]{ + {num, num + 1}, + }, + }) +} + +func TestLostCRYPTOFrame(t *testing.T) { + // "Data sent in CRYPTO frames is retransmitted [...] until all data has been acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.1 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.triggerLossOrPTO(packetTypeInitial, pto) + tc.wantFrame("client resends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + + tc.wantFrame("client sends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) + tc.triggerLossOrPTO(packetTypeHandshake, pto) + tc.wantFrame("client resends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) + }) +} + +func TestLostHandshakeDoneFrame(t *testing.T) { + // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16 + lostFrameTest(t, func(t *testing.T, pto bool) { + tc := newTestConn(t, serverSide) + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("server sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("server sends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + + tc.wantFrame("server sends HANDSHAKE_DONE after handshake completes", + packetType1RTT, debugFrameHandshakeDone{}) + tc.wantFrame("server sends session ticket in CRYPTO frame", + packetType1RTT, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelApplication], + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantFrame("server resends HANDSHAKE_DONE", + packetType1RTT, debugFrameHandshakeDone{}) + tc.wantFrame("server resends session ticket", + packetType1RTT, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelApplication], + }) + }) +} diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index d5a3b8cb0..7eb03e727 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -41,12 +41,12 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { } func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int { - if !c.tlsState.rkeys[space].isSet() { + if !c.rkeys[space].isSet() { return skipLongHeaderPacket(buf) } pnumMax := c.acks[space].largestSeen() - p, n := parseLongHeaderPacket(buf, c.tlsState.rkeys[space], pnumMax) + p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax) if n < 0 { return -1 } @@ -66,21 +66,23 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa if p.ptype == packetTypeHandshake && c.side == serverSide { c.loss.validateClientAddress() - // TODO: Discard Initial keys. + // "[...] a server MUST discard Initial keys when it first successfully + // processes a Handshake packet [...]" // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.1-2 + c.discardKeys(now, initialSpace) } return n } func (c *Conn) handle1RTT(now time.Time, buf []byte) int { - if !c.tlsState.rkeys[appDataSpace].isSet() { + if !c.rkeys[appDataSpace].isSet() { // 1-RTT packets extend to the end of the datagram, // so skip the remainder of the datagram if we can't parse this. return len(buf) } pnumMax := c.acks[appDataSpace].largestSeen() - p, n := parse1RTTPacket(buf, c.tlsState.rkeys[appDataSpace], connIDLen, pnumMax) + p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax) if n < 0 { return -1 } @@ -163,7 +165,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, if !frameOK(c, ptype, IH_1) { return } - _, _, n = consumeCryptoFrame(payload) + n = c.handleCryptoFrame(now, space, payload) case frameTypeNewToken: if !frameOK(c, ptype, ___1) { return @@ -207,14 +209,18 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, case frameTypeConnectionCloseTransport: // CONNECTION_CLOSE is OK in all spaces. _, _, _, n = consumeConnectionCloseTransportFrame(payload) + // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 + c.abort(now, localTransportError(errNo)) case frameTypeConnectionCloseApplication: // CONNECTION_CLOSE is OK in all spaces. _, _, n = consumeConnectionCloseApplicationFrame(payload) + // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 + c.abort(now, localTransportError(errNo)) case frameTypeHandshakeDone: if !frameOK(c, ptype, ___1) { return } - n = 1 + n = c.handleHandshakeDoneFrame(now, space, payload) } if n < 0 { c.abort(now, localTransportError(errFrameEncoding)) @@ -262,3 +268,24 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss) return n } + +func (c *Conn) handleCryptoFrame(now time.Time, space numberSpace, payload []byte) int { + off, data, n := consumeCryptoFrame(payload) + err := c.handleCrypto(now, space, off, data) + if err != nil { + c.abort(now, err) + return -1 + } + return n +} + +func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int { + if c.side == serverSide { + // Clients should never send HANDSHAKE_DONE. + // https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4 + c.abort(now, localTransportError(errProtocolViolation)) + return -1 + } + c.confirmHandshake(now) + return 1 +} diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 3a51ceb28..71d24e6f0 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -7,6 +7,8 @@ package quic import ( + "crypto/tls" + "errors" "time" ) @@ -45,7 +47,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Initial packet. pad := false var sentInitial *sentPacket - if k := c.tlsState.wkeys[initialSpace]; k.isSet() { + if k := c.wkeys[initialSpace]; k.isSet() { pnumMaxAcked := c.acks[initialSpace].largestSeen() pnum := c.loss.nextNumber(initialSpace) p := longPacket{ @@ -62,14 +64,14 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Client initial packets need to be sent in a datagram padded to // at least 1200 bytes. We can't add the padding yet, however, // since we may want to coalesce additional packets with this one. - if c.side == clientSide || sentInitial.ackEliciting { + if c.side == clientSide { pad = true } } } // Handshake packet. - if k := c.tlsState.wkeys[handshakeSpace]; k.isSet() { + if k := c.wkeys[handshakeSpace]; k.isSet() { pnumMaxAcked := c.acks[handshakeSpace].largestSeen() pnum := c.loss.nextNumber(handshakeSpace) p := longPacket{ @@ -84,14 +86,16 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil { c.loss.packetSent(now, handshakeSpace, sent) if c.side == clientSide { - // TODO: Discard the Initial keys. - // https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1 + // "[...] a client MUST discard Initial keys when it first + // sends a Handshake packet [...]" + // https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1-2 + c.discardKeys(now, initialSpace) } } } // 1-RTT packet. - if k := c.tlsState.wkeys[appDataSpace]; k.isSet() { + if k := c.wkeys[appDataSpace]; k.isSet() { pnumMaxAcked := c.acks[appDataSpace].largestSeen() pnum := c.loss.nextNumber(appDataSpace) dstConnID := c.connIDState.dstConnID() @@ -133,7 +137,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { sentInitial.inFlight = true } } - if k := c.tlsState.wkeys[initialSpace]; k.isSet() { + if k := c.wkeys[initialSpace]; k.isSet() { c.loss.packetSent(now, initialSpace, sentInitial) } } @@ -143,6 +147,26 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) { + if c.errForPeer != nil { + // This is the bare minimum required to send a CONNECTION_CLOSE frame + // when closing a connection immediately, for example in response to a + // protocol error. + // + // This does not handle the closing and draining states + // (https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2), + // but it's enough to let us write tests that result in a CONNECTION_CLOSE, + // and have those tests still pass when we finish implementing + // connection shutdown. + // + // TODO: Finish implementing connection shutdown. + if !c.connCloseSent[space] { + c.exited = true + c.appendConnectionCloseFrame(c.errForPeer) + c.connCloseSent[space] = true + } + return + } + shouldSendAck := c.acks[space].shouldSendAck(now) if limit != ccOK { // ACKs are not limited by congestion control. @@ -185,6 +209,21 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, // TODO: Add all the other frames we can send. + // HANDSHAKE_DONE + if c.handshakeConfirmed.shouldSendPTO(pto) { + if !c.w.appendHandshakeDoneFrame() { + return + } + c.handshakeConfirmed.setSent(pnum) + } + + // CRYPTO + c.crypto[space].dataToSend(pto, func(off, size int64) int64 { + b, _ := c.w.appendCryptoFrame(off, int(size)) + c.crypto[space].sendData(off, b) + return int64(len(b)) + }) + // Test-only PING frames. if space == c.testSendPingSpace && c.testSendPing.shouldSendPTO(pto) { if !c.w.appendPingFrame() { @@ -253,3 +292,22 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { d := unscaledAckDelayFromDuration(delay, ackDelayExponent) return c.w.appendAckFrame(seen, d) } + +func (c *Conn) appendConnectionCloseFrame(err error) { + // TODO: Send application errors. + switch e := err.(type) { + case localTransportError: + c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "") + default: + // TLS alerts are sent using error codes [0x0100,0x01ff). + // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1 + var alert tls.AlertError + if errors.As(err, &alert) { + // tls.AlertError is a uint8, so this can't exceed 0x01ff. + code := errTLSBase + transportError(alert) + c.w.appendConnectionCloseTransportFrame(code, 0, "") + return + } + c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index fda1d4b86..511fb97a0 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -7,6 +7,9 @@ package quic import ( + "bytes" + "context" + "crypto/tls" "errors" "fmt" "math" @@ -111,8 +114,22 @@ type testConn struct { // we use Handshake keys to encrypt the packet. // The client only acquires those keys when it processes // the Initial packet. - rkeys [numberSpaceCount]keys // for packets sent to the conn - wkeys [numberSpaceCount]keys // for packets sent by the conn + rkeys [numberSpaceCount]keyData // for packets sent to the conn + wkeys [numberSpaceCount]keyData // for packets sent by the conn + + // testConn uses a test hook to snoop on the conn's TLS events. + // CRYPTO data produced by the conn's QUICConn is placed in + // cryptoDataOut. + // + // The peerTLSConn is is a QUICConn representing the peer. + // CRYPTO data produced by the conn is written to peerTLSConn, + // and data produced by peerTLSConn is placed in cryptoDataIn. + cryptoDataOut map[tls.QUICEncryptionLevel][]byte + cryptoDataIn map[tls.QUICEncryptionLevel][]byte + peerTLSConn *tls.QUICConn + + localConnID []byte + transientConnID []byte // Information about the conn's (fake) peer. peerConnID []byte // source conn id of peer's packets @@ -129,12 +146,18 @@ type testConn struct { ignoreFrames map[byte]bool } +type keyData struct { + suite uint16 + secret []byte + k keys +} + // newTestConn creates a Conn for testing. // // The Conn's event loop is controlled by the test, // allowing test code to access Conn state directly // by first ensuring the loop goroutine is idle. -func newTestConn(t *testing.T, side connSide) *testConn { +func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { t.Helper() tc := &testConn{ t: t, @@ -143,9 +166,24 @@ func newTestConn(t *testing.T, side connSide) *testConn { ignoreFrames: map[byte]bool{ frameTypePadding: true, // ignore PADDING by default }, + cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte), + cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte), } t.Cleanup(tc.cleanup) + config := &Config{ + TLSConfig: newTestTLSConfig(side), + } + peerProvidedParams := defaultTransportParameters() + for _, o := range opts { + switch o := o.(type) { + case func(*tls.Config): + o(config.TLSConfig) + default: + t.Fatalf("unknown newTestConn option %T", o) + } + } + var initialConnID []byte if side == serverSide { // The initial connection ID for the server is chosen by the client. @@ -157,11 +195,21 @@ func newTestConn(t *testing.T, side connSide) *testConn { } } + peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())} + if side == clientSide { + tc.peerTLSConn = tls.QUICServer(peerQUICConfig) + } else { + tc.peerTLSConn = tls.QUICClient(peerQUICConfig) + } + tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) + tc.peerTLSConn.Start(context.Background()) + conn, err := newConn( tc.now, side, initialConnID, netip.MustParseAddrPort("127.0.0.1:443"), + config, (*testConnListener)(tc), (*testConnHooks)(tc)) if err != nil { @@ -169,8 +217,16 @@ func newTestConn(t *testing.T, side connSide) *testConn { } tc.conn = conn - tc.wkeys[initialSpace] = conn.tlsState.wkeys[initialSpace] - tc.rkeys[initialSpace] = conn.tlsState.rkeys[initialSpace] + if side == serverSide { + tc.transientConnID = tc.conn.connIDState.local[0].cid + tc.localConnID = tc.conn.connIDState.local[1].cid + } else if side == clientSide { + tc.transientConnID = tc.conn.connIDState.remote[0].cid + tc.localConnID = tc.conn.connIDState.local[0].cid + } + + tc.wkeys[initialSpace].k = conn.wkeys[initialSpace] + tc.rkeys[initialSpace].k = conn.rkeys[initialSpace] tc.wait() return tc @@ -385,7 +441,7 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) } if gotType != wantType { - tc.t.Fatalf("%v:\ngot %v packet, want %v", expectation, wantType, want) + tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) } if !reflect.DeepEqual(got, want) { tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want) @@ -426,12 +482,12 @@ func (tc *testConn) encodeTestPacket(p *testPacket) []byte { f.write(&w) } space := spaceForPacketType(p.ptype) - if !tc.rkeys[space].isSet() { + if !tc.rkeys[space].k.isSet() { tc.t.Fatalf("sending packet with no %v keys available", space) return nil } if p.ptype != packetType1RTT { - w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space], longPacket{ + w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{ ptype: p.ptype, version: p.version, num: p.num, @@ -439,7 +495,7 @@ func (tc *testConn) encodeTestPacket(p *testPacket) []byte { srcConnID: p.srcConnID, }) } else { - w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space]) + w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k) } return w.datagram() } @@ -455,12 +511,12 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { } ptype := getPacketType(buf) space := spaceForPacketType(ptype) - if !tc.wkeys[space].isSet() { + if !tc.wkeys[space].k.isSet() { tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype) } if isLongHeader(buf[0]) { var pnumMax packetNumber // TODO: Track packet numbers. - p, n := parseLongHeaderPacket(buf, tc.wkeys[space], pnumMax) + p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax) if n < 0 { tc.t.Fatalf("packet parse error") } @@ -479,11 +535,10 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { buf = buf[n:] } else { var pnumMax packetNumber // TODO: Track packet numbers. - p, n := parse1RTTPacket(buf, tc.wkeys[space], len(tc.peerConnID), pnumMax) + p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax) if n < 0 { tc.t.Fatalf("packet parse error") } - dstConnID, _ := dstConnIDForDatagram(buf) frames, err := tc.parseTestFrames(p.payload) if err != nil { tc.t.Fatal(err) @@ -491,7 +546,7 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { d.packets = append(d.packets, &testPacket{ ptype: packetType1RTT, num: p.num, - dstConnID: dstConnID, + dstConnID: buf[1:][:len(tc.peerConnID)], frames: frames, }) buf = buf[n:] @@ -535,6 +590,73 @@ func spaceForPacketType(ptype packetType) numberSpace { // testConnHooks implements connTestHooks. type testConnHooks testConn +// handleTLSEvent processes TLS events generated by +// the connection under test's tls.QUICConn. +// +// We maintain a second tls.QUICConn representing the peer, +// and feed the TLS handshake data into it. +// +// We stash TLS handshake data from both sides in the testConn, +// where it can be used by tests. +// +// We snoop packet protection keys out of the tls.QUICConns, +// and verify that both sides of the connection are getting +// matching keys. +func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { + setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) { + k, err := newKeys(e.Suite, e.Data) + if err != nil { + tc.t.Errorf("newKeys: %v", err) + return + } + var space numberSpace + switch { + case e.Level == tls.QUICEncryptionLevelHandshake: + space = handshakeSpace + case e.Level == tls.QUICEncryptionLevelApplication: + space = appDataSpace + default: + tc.t.Errorf("unexpected encryption level %v", e.Level) + return + } + s := "read" + if keys == &tc.wkeys { + s = "write" + } + if keys[space].k.isSet() { + if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) { + tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level) + } + return + } + keys[space].suite = e.Suite + keys[space].secret = append([]byte{}, e.Data...) + keys[space].k = k + } + switch e.Kind { + case tls.QUICSetReadSecret: + setKey(&tc.rkeys, e) + case tls.QUICSetWriteSecret: + setKey(&tc.wkeys, e) + case tls.QUICWriteData: + tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...) + tc.peerTLSConn.HandleData(e.Level, e.Data) + } + for { + e := tc.peerTLSConn.NextEvent() + switch e.Kind { + case tls.QUICNoEvent: + return + case tls.QUICSetReadSecret: + setKey(&tc.wkeys, e) + case tls.QUICSetWriteSecret: + setKey(&tc.rkeys, e) + case tls.QUICWriteData: + tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...) + } + } +} + // nextMessage is called by the Conn's event loop to request its next event. func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { tc.timer = timer diff --git a/internal/quic/ping_test.go b/internal/quic/ping_test.go index 4a732ed54..c370aaf1d 100644 --- a/internal/quic/ping_test.go +++ b/internal/quic/ping_test.go @@ -10,26 +10,34 @@ import "testing" func TestPing(t *testing.T) { tc := newTestConn(t, clientSide) - tc.conn.ping(initialSpace) + tc.handshake() + + tc.conn.ping(appDataSpace) tc.wantFrame("connection should send a PING frame", - packetTypeInitial, debugFramePing{}) + packetType1RTT, debugFramePing{}) tc.advanceToTimer() tc.wantFrame("on PTO, connection should send another PING frame", - packetTypeInitial, debugFramePing{}) + packetType1RTT, debugFramePing{}) tc.wantIdle("after sending PTO probe, no additional frames to send") } func TestAck(t *testing.T) { tc := newTestConn(t, serverSide) - tc.writeFrames(packetTypeInitial, + tc.handshake() + + // Send two packets, to trigger an immediate ACK. + tc.writeFrames(packetType1RTT, + debugFramePing{}, + ) + tc.writeFrames(packetType1RTT, debugFramePing{}, ) tc.wantFrame("connection should respond to ack-eliciting packet with an ACK frame", - packetTypeInitial, + packetType1RTT, debugFrameAck{ - ranges: []i64range[packetNumber]{{0, 1}}, + ranges: []i64range[packetNumber]{{0, 3}}, }, ) } diff --git a/internal/quic/quic.go b/internal/quic/quic.go index 9df7f7e2b..a61c91f16 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -64,6 +64,14 @@ func (s connSide) String() string { } } +func (s connSide) peer() connSide { + if s == clientSide { + return serverSide + } else { + return clientSide + } +} + // A numberSpace is the context in which a packet number applies. // https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3-7 type numberSpace byte diff --git a/internal/quic/tls.go b/internal/quic/tls.go index 1cdb727e2..4306a3e46 100644 --- a/internal/quic/tls.go +++ b/internal/quic/tls.go @@ -6,18 +6,132 @@ package quic -// tlsState encapsulates interactions with TLS. -type tlsState struct { - // Encryption keys indexed by number space. - rkeys [numberSpaceCount]keys - wkeys [numberSpaceCount]keys -} +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "time" +) -func (s *tlsState) init(side connSide, initialConnID []byte) { +// startTLS starts the TLS handshake. +func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error { clientKeys, serverKeys := initialKeys(initialConnID) - if side == clientSide { - s.wkeys[initialSpace], s.rkeys[initialSpace] = clientKeys, serverKeys + if c.side == clientSide { + c.wkeys[initialSpace], c.rkeys[initialSpace] = clientKeys, serverKeys } else { - s.wkeys[initialSpace], s.rkeys[initialSpace] = serverKeys, clientKeys + c.wkeys[initialSpace], c.rkeys[initialSpace] = serverKeys, clientKeys + } + + qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig} + if c.side == clientSide { + c.tls = tls.QUICClient(qconfig) + } else { + c.tls = tls.QUICServer(qconfig) + } + c.tls.SetTransportParameters(marshalTransportParameters(params)) + // TODO: We don't need or want a context for cancelation here, + // but users can use a context to plumb values through to hooks defined + // in the tls.Config. Pass through a context. + if err := c.tls.Start(context.TODO()); err != nil { + return err + } + return c.handleTLSEvents(now) +} + +func (c *Conn) handleTLSEvents(now time.Time) error { + for { + e := c.tls.NextEvent() + if c.testHooks != nil { + c.testHooks.handleTLSEvent(e) + } + switch e.Kind { + case tls.QUICNoEvent: + return nil + case tls.QUICSetReadSecret: + space, k, err := tlsKey(e) + if err != nil { + return err + } + c.rkeys[space] = k + case tls.QUICSetWriteSecret: + space, k, err := tlsKey(e) + if err != nil { + return err + } + c.wkeys[space] = k + case tls.QUICWriteData: + space, err := spaceForLevel(e.Level) + if err != nil { + return err + } + c.crypto[space].write(e.Data) + case tls.QUICHandshakeDone: + if c.side == serverSide { + // "[...] the TLS handshake is considered confirmed + // at the server when the handshake completes." + // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1 + c.confirmHandshake(now) + if !c.config.TLSConfig.SessionTicketsDisabled { + if err := c.tls.SendSessionTicket(false); err != nil { + return err + } + } + } + case tls.QUICTransportParameters: + params, err := unmarshalTransportParams(e.Data) + if err != nil { + return err + } + c.receiveTransportParameters(params) + } + } +} + +// tlsKey returns the keys in a QUICSetReadSecret or QUICSetWriteSecret event. +func tlsKey(e tls.QUICEvent) (numberSpace, keys, error) { + space, err := spaceForLevel(e.Level) + if err != nil { + return 0, keys{}, err + } + k, err := newKeys(e.Suite, e.Data) + if err != nil { + return 0, keys{}, err + } + return space, k, nil +} + +func spaceForLevel(level tls.QUICEncryptionLevel) (numberSpace, error) { + switch level { + case tls.QUICEncryptionLevelInitial: + return initialSpace, nil + case tls.QUICEncryptionLevelHandshake: + return handshakeSpace, nil + case tls.QUICEncryptionLevelApplication: + return appDataSpace, nil + default: + return 0, fmt.Errorf("quic: internal error: write handshake data at level %v", level) + } +} + +// handleCrypto processes data received in a CRYPTO frame. +func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error { + var level tls.QUICEncryptionLevel + switch space { + case initialSpace: + level = tls.QUICEncryptionLevelInitial + case handshakeSpace: + level = tls.QUICEncryptionLevelHandshake + case appDataSpace: + level = tls.QUICEncryptionLevelApplication + default: + return errors.New("quic: internal error: received CRYPTO frame in unexpected number space") + } + err := c.crypto[space].handleCrypto(off, data, func(b []byte) error { + return c.tls.HandleData(level, b) + }) + if err != nil { + return err } + return c.handleTLSEvents(now) } diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go new file mode 100644 index 000000000..df0782008 --- /dev/null +++ b/internal/quic/tls_test.go @@ -0,0 +1,421 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "reflect" + "testing" + "time" +) + +// handshake executes the handshake. +func (tc *testConn) handshake() { + tc.t.Helper() + defer func(saved map[byte]bool) { + tc.ignoreFrames = saved + }(tc.ignoreFrames) + tc.ignoreFrames = nil + t := tc.t + dgrams := handshakeDatagrams(tc) + i := 0 + for { + if i == len(dgrams)-1 { + if tc.conn.side == clientSide { + want := tc.now.Add(maxAckDelay - timerGranularity) + if !tc.timer.Equal(want) { + t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer) + } + if got := tc.readDatagram(); got != nil { + t.Fatalf("client unexpectedly sent: %v", got) + } + } + tc.advance(maxAckDelay) + } + + // Check that we're sending exactly the data we expect. + // Any variation from the norm here should be intentional. + got := tc.readDatagram() + var want *testDatagram + if !(tc.conn.side == serverSide && i == 0) && i < len(dgrams) { + want = dgrams[i] + fillCryptoFrames(want, tc.cryptoDataOut) + i++ + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want) + } + if i >= len(dgrams) { + break + } + + fillCryptoFrames(dgrams[i], tc.cryptoDataIn) + tc.write(dgrams[i]) + i++ + } +} + +func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { + var ( + clientConnID []byte + serverConnID []byte + ) + if tc.conn.side == clientSide { + clientConnID = tc.localConnID + serverConnID = tc.peerConnID + } else { + clientConnID = tc.peerConnID + serverConnID = tc.localConnID + } + return []*testDatagram{{ + // Client Initial + packets: []*testPacket{{ + ptype: packetTypeInitial, + num: 0, + version: 1, + srcConnID: clientConnID, + dstConnID: tc.transientConnID, + frames: []debugFrame{ + debugFrameCrypto{}, + }, + }}, + paddedSize: 1200, + }, { + // Server Initial + Handshake + packets: []*testPacket{{ + ptype: packetTypeInitial, + num: 0, + version: 1, + srcConnID: serverConnID, + dstConnID: clientConnID, + frames: []debugFrame{ + debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 1}}, + }, + debugFrameCrypto{}, + }, + }, { + ptype: packetTypeHandshake, + num: 0, + version: 1, + srcConnID: serverConnID, + dstConnID: clientConnID, + frames: []debugFrame{ + debugFrameCrypto{}, + }, + }}, + }, { + // Client Handshake + packets: []*testPacket{{ + ptype: packetTypeInitial, + num: 1, + version: 1, + srcConnID: clientConnID, + dstConnID: serverConnID, + frames: []debugFrame{ + debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 1}}, + }, + }, + }, { + ptype: packetTypeHandshake, + num: 0, + version: 1, + srcConnID: clientConnID, + dstConnID: serverConnID, + frames: []debugFrame{ + debugFrameAck{ + ranges: []i64range[packetNumber]{{0, 1}}, + }, + debugFrameCrypto{}, + }, + }}, + paddedSize: 1200, + }, { + // Server HANDSHAKE_DONE and session ticket + packets: []*testPacket{{ + ptype: packetType1RTT, + num: 0, + dstConnID: clientConnID, + frames: []debugFrame{ + debugFrameHandshakeDone{}, + debugFrameCrypto{}, + }, + }}, + }, { + // Client ack (after max_ack_delay) + packets: []*testPacket{{ + ptype: packetType1RTT, + num: 0, + dstConnID: serverConnID, + frames: []debugFrame{ + debugFrameAck{ + ackDelay: unscaledAckDelayFromDuration( + maxAckDelay, ackDelayExponent), + ranges: []i64range[packetNumber]{{0, 1}}, + }, + }, + }}, + }} +} + +func fillCryptoFrames(d *testDatagram, data map[tls.QUICEncryptionLevel][]byte) { + for _, p := range d.packets { + var level tls.QUICEncryptionLevel + switch p.ptype { + case packetTypeInitial: + level = tls.QUICEncryptionLevelInitial + case packetTypeHandshake: + level = tls.QUICEncryptionLevelHandshake + case packetType1RTT: + level = tls.QUICEncryptionLevelApplication + default: + continue + } + for i := range p.frames { + c, ok := p.frames[i].(debugFrameCrypto) + if !ok { + continue + } + c.data = data[level] + data[level] = nil + p.frames[i] = c + } + } +} + +func TestConnClientHandshake(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.handshake() + tc.advance(1 * time.Second) + tc.wantIdle("no packets should be sent by an idle conn after the handshake") +} + +func TestConnServerHandshake(t *testing.T) { + tc := newTestConn(t, serverSide) + tc.handshake() + tc.advance(1 * time.Second) + tc.wantIdle("no packets should be sent by an idle conn after the handshake") +} + +func TestConnKeysDiscardedClient(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrame("client sends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) + + // The client discards Initial keys after sending a Handshake packet. + tc.writeFrames(packetTypeInitial, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantIdle("client has discarded Initial keys, cannot read CONNECTION_CLOSE") + + // The client discards Handshake keys after receiving a HANDSHAKE_DONE frame. + tc.writeFrames(packetType1RTT, + debugFrameHandshakeDone{}) + tc.writeFrames(packetTypeHandshake, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantIdle("client has discarded Handshake keys, cannot read CONNECTION_CLOSE") + + tc.writeFrames(packetType1RTT, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantFrame("client closes connection after 1-RTT CONNECTION_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) +} + +func TestConnKeysDiscardedServer(t *testing.T) { + tc := newTestConn(t, serverSide, func(c *tls.Config) { + c.SessionTicketsDisabled = true + }) + tc.ignoreFrame(frameTypeAck) + + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("server sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("server sends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) + + // The server discards Initial keys after receiving a Handshake packet. + // The Handshake packet contains only the start of the client's CRYPTO flight here, + // to avoids completing the handshake yet. + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], + }) + tc.writeFrames(packetTypeInitial, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantIdle("server has discarded Initial keys, cannot read CONNECTION_CLOSE") + + // The server discards Handshake keys after sending a HANDSHAKE_DONE frame. + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + off: 1, + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][1:], + }) + tc.wantFrame("server sends HANDSHAKE_DONE after handshake completes", + packetType1RTT, debugFrameHandshakeDone{}) + tc.writeFrames(packetTypeHandshake, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantIdle("server has discarded Handshake keys, cannot read CONNECTION_CLOSE") + + tc.writeFrames(packetType1RTT, + debugFrameConnectionCloseTransport{code: errInternal}) + tc.wantFrame("server closes connection after 1-RTT CONNECTION_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) +} + +func TestConnInvalidCryptoData(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + + // Render the server's response invalid. + // + // The client closes the connection with CRYPTO_ERROR. + // + // Changing the first byte will change the TLS message type, + // so we can reasonably assume that this is an unexpected_message alert (10). + tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][0] ^= 0x1 + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrame("client closes connection due to TLS handshake error", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errTLSBase + 10, + }) +} + +func TestConnInvalidPeerCertificate(t *testing.T) { + tc := newTestConn(t, clientSide, func(c *tls.Config) { + c.VerifyPeerCertificate = func([][]byte, [][]*x509.Certificate) error { + return errors.New("I will not buy this certificate. It is scratched.") + } + }) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrame("client closes connection due to rejecting server certificate", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errTLSBase + 42, // 42: bad_certificate + }) +} + +func TestConnHandshakeDoneSentToServer(t *testing.T) { + tc := newTestConn(t, serverSide) + tc.handshake() + + tc.writeFrames(packetType1RTT, + debugFrameHandshakeDone{}) + tc.wantFrame("server closes connection when client sends a HANDSHAKE_DONE frame", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errProtocolViolation, + }) +} + +func TestConnCryptoDataOutOfOrder(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantIdle("client is idle, server Handshake flight has not arrived") + + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + off: 15, + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][15:], + }) + tc.wantIdle("client is idle, server Handshake flight is not complete") + + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + off: 1, + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][1:20], + }) + tc.wantIdle("client is idle, server Handshake flight is still not complete") + + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][0:1], + }) + tc.wantFrame("client sends Handshake CRYPTO frame", + packetTypeHandshake, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake], + }) +} + +func TestConnCryptoBufferSizeExceeded(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + off: cryptoBufferSize, + data: []byte{0}, + }) + tc.wantFrame("client closes connection after server exceeds CRYPTO buffer", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errCryptoBufferExceeded, + }) +} diff --git a/internal/quic/tlsconfig_test.go b/internal/quic/tlsconfig_test.go new file mode 100644 index 000000000..47bfb0598 --- /dev/null +++ b/internal/quic/tlsconfig_test.go @@ -0,0 +1,62 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "crypto/tls" + "strings" +) + +func newTestTLSConfig(side connSide) *tls.Config { + config := &tls.Config{ + InsecureSkipVerify: true, + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + MinVersion: tls.VersionTLS13, + } + if side == serverSide { + config.Certificates = []tls.Certificate{testCert} + } + return config +} + +var testCert = func() tls.Certificate { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + panic(err) + } + return cert +}() + +// localhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. +// generated from src/crypto/tls: +// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO +BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa +MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh +WRF8p8X9scgW7JjqAwI9nYV8jtkdhqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGms +PyfMPe5Jrha/LmjgR1G9o4GIMIGFMA4GA1UdDwEB/wQEAwIChDATBgNVHSUEDDAK +BggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSOJri/wLQxq6oC +Y6ZImms/STbTljAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAA +AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40 +HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(testingKey(`-----BEGIN TESTING KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs +rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd +hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9 +-----END TESTING KEY-----`)) + +// testingKey helps keep security scanners from getting excited about a private key in this file. +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/internal/quic/transport_params.go b/internal/quic/transport_params.go index 416bfb867..89ea69fb9 100644 --- a/internal/quic/transport_params.go +++ b/internal/quic/transport_params.go @@ -25,7 +25,7 @@ type transportParameters struct { initialMaxStreamDataUni int64 initialMaxStreamsBidi int64 initialMaxStreamsUni int64 - ackDelayExponent uint8 + ackDelayExponent int8 maxAckDelay time.Duration disableActiveMigration bool preferredAddrV4 netip.AddrPort @@ -220,7 +220,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) { if v > 20 { return p, localTransportError(errTransportParameter) } - p.ackDelayExponent = uint8(v) + p.ackDelayExponent = int8(v) case paramMaxAckDelay: var v uint64 v, n = consumeVarint(val)