diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 1c71a23aa8a..1d17f9df5cd 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "io/ioutil" mrand "math/rand" "net" "sync/atomic" @@ -31,6 +32,7 @@ var _ = Describe("Handshake drop tests", func() { ln quic.Listener ) + data := GeneratePRData(5000) const timeout = 2 * time.Minute startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { @@ -77,10 +79,9 @@ var _ = Describe("Handshake drop tests", func() { defer sess.CloseWithError(0, "") str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) + b, err := ioutil.ReadAll(gbytes.TimeoutReader(str, timeout)) Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foobar")) + Expect(b).To(Equal(data)) serverSessionChan <- sess }() sess, err := quic.DialAddr( @@ -95,8 +96,9 @@ var _ = Describe("Handshake drop tests", func() { Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) + _, err = str.Write(data) Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) var serverSession quic.Session Eventually(serverSessionChan, timeout).Should(Receive(&serverSession)) @@ -115,8 +117,9 @@ var _ = Describe("Handshake drop tests", func() { Expect(err).ToNot(HaveOccurred()) str, err := sess.OpenStream() Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) + _, err = str.Write(data) Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) serverSessionChan <- sess }() sess, err := quic.DialAddr( @@ -131,10 +134,9 @@ var _ = Describe("Handshake drop tests", func() { Expect(err).ToNot(HaveOccurred()) str, err := sess.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b) + b, err := ioutil.ReadAll(gbytes.TimeoutReader(str, timeout)) Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foobar")) + Expect(b).To(Equal(data)) var serverSession quic.Session Eventually(serverSessionChan, timeout).Should(Receive(&serverSession)) diff --git a/packet_packer.go b/packet_packer.go index 84b5cb5c116..858fcd71498 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -467,7 +467,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { var s cryptoStream var hasRetransmission bool //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. @@ -494,19 +494,19 @@ func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCo if ack != nil { payload.ack = ack payload.length = ack.Length(p.version) - maxSize -= payload.length + maxPacketSize -= payload.length } hdr := p.getLongHeader(encLevel) - maxSize -= hdr.GetLength(p.version) + maxPacketSize -= hdr.GetLength(p.version) if hasRetransmission { for { var f wire.Frame //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s switch encLevel { case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxSize) + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxSize) + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) } if f == nil { break @@ -514,10 +514,10 @@ func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCo payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) frameLen := f.Length(p.version) payload.length += frameLen - maxSize -= frameLen + maxPacketSize -= frameLen } } else if s.HasData() { - cf := s.PopCryptoFrame(maxSize) + cf := s.PopCryptoFrame(maxPacketSize) payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } @@ -547,18 +547,19 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol } maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize, encLevel) + payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) return sealer, hdr, payload } -func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) *payload { - payload := p.composeNextPacket(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) +func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { + payload := p.composeNextPacket(maxPayloadSize, ackAllowed) // check if we have anything to send - if len(payload.frames) == 0 && payload.ack == nil { - return nil - } - if len(payload.frames) == 0 { // the packet only contains an ACK + if len(payload.frames) == 0 { + if payload.ack == nil { + return nil + } + // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping}) @@ -642,14 +643,12 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( return nil, err } sealer = oneRTTSealer - payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.Encryption1RTT) - if payload != nil { - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - } + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) default: panic("unknown encryption level") } - if hdr == nil { + if payload == nil { return nil, nil } size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) diff --git a/packet_packer_test.go b/packet_packer_test.go index cf1f385008f..49c9f23b941 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -1322,6 +1322,25 @@ var _ = Describe("Packet packer", func() { parsePacket(packet.buffer.Data) }) + It("packs a full size Handshake probe packet", func() { + f := &wire.CryptoFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddHandshake(f) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) + handshakeStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + parsePacket(packet.buffer.Data) + }) + It("packs a 1-RTT probe packet", func() { f := &wire.StreamFrame{Data: []byte("1-RTT")} retransmissionQueue.AddInitial(f) @@ -1341,8 +1360,33 @@ var _ = Describe("Packet packer", func() { Expect(packet.frames[0].Frame).To(Equal(f)) }) + It("packs a full size 1-RTT probe packet", func() { + f := &wire.StreamFrame{Data: make([]byte, 2000)} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + framer.EXPECT().HasData().Return(true) + expectAppendControlFrames() + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + sf, split := f.MaybeSplitOffFrame(maxSize, packer.version) + Expect(split).To(BeTrue()) + return append(fs, ackhandler.Frame{Frame: sf}), sf.Length(packer.version) + }) + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packet.length).To(Equal(maxPacketSize)) + }) + It("returns nil if there's no probe data to send", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData()