From e07df3fa2f0333fed1759bc25be90c46da10d549 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Fri, 12 Mar 2021 16:39:26 +0000 Subject: [PATCH] Fix a race condition that could allow same IV to be used more than once if timed correctly --- internal/bloomring.go | 19 +++++++++++++++++++ internal/saltfilter.go | 4 ++++ shadowaead/packet.go | 7 +++---- shadowaead/stream.go | 8 ++++---- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/internal/bloomring.go b/internal/bloomring.go index c89a4e92..229a37b8 100644 --- a/internal/bloomring.go +++ b/internal/bloomring.go @@ -46,6 +46,10 @@ func (r *BloomRing) Add(b []byte) { } r.mutex.Lock() defer r.mutex.Unlock() + r.add(b) +} + +func (r *BloomRing) add(b []byte) { slot := r.slots[r.slotPosition] if r.entryCounter > r.slotCapacity { // Move to next slot and reset @@ -64,6 +68,11 @@ func (r *BloomRing) Test(b []byte) bool { } r.mutex.RLock() defer r.mutex.RUnlock() + test := r.test(b) + return test +} + +func (r *BloomRing) test(b []byte) bool { for _, s := range r.slots { if s.Test(b) { return true @@ -71,3 +80,13 @@ func (r *BloomRing) Test(b []byte) bool { } return false } + +func (r *BloomRing) Check(b []byte) bool { + r.mutex.Lock() + defer r.mutex.Unlock() + if r.Test(b) { + return true + } + r.Add(b) + return false +} diff --git a/internal/saltfilter.go b/internal/saltfilter.go index edb7e2d1..9ff61623 100644 --- a/internal/saltfilter.go +++ b/internal/saltfilter.go @@ -79,3 +79,7 @@ func TestSalt(b []byte) bool { func AddSalt(b []byte) { getSaltFilterSingleton().Add(b) } + +func CheckSalt(b []byte) bool { + return getSaltFilterSingleton().Test(b) +} diff --git a/shadowaead/packet.go b/shadowaead/packet.go index 6f48f14c..2ba403fb 100644 --- a/shadowaead/packet.go +++ b/shadowaead/packet.go @@ -46,14 +46,13 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { return nil, ErrShortPacket } salt := pkt[:saltSize] - if internal.TestSalt(salt) { - return nil, ErrRepeatedSalt - } aead, err := ciph.Decrypter(salt) if err != nil { return nil, err } - internal.AddSalt(salt) + if internal.CheckSalt(salt) { + return nil, ErrRepeatedSalt + } if len(pkt) < saltSize+aead.Overhead() { return nil, ErrShortPacket } diff --git a/shadowaead/stream.go b/shadowaead/stream.go index a41e14ea..251f5f56 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -205,14 +205,14 @@ func (c *streamConn) initReader() error { if _, err := io.ReadFull(c.Conn, salt); err != nil { return err } - if internal.TestSalt(salt) { - return ErrRepeatedSalt - } aead, err := c.Decrypter(salt) if err != nil { return err } - internal.AddSalt(salt) + + if internal.CheckSalt(salt) { + return ErrRepeatedSalt + } c.r = newReader(c.Conn, aead) return nil