Skip to content

Commit

Permalink
Restore original internal interface
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Oct 1, 2020
1 parent c071069 commit 15e3c6b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
12 changes: 11 additions & 1 deletion internal/saltfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var initSaltfilterOnce sync.Once

// GetSaltFilterSingleton returns the BloomRing singleton,
// initializing it on first call.
func GetSaltFilterSingleton() *BloomRing {
func getSaltFilterSingleton() *BloomRing {
initSaltfilterOnce.Do(func() {
var (
finalCapacity = DefaultSFCapacity
Expand Down Expand Up @@ -69,3 +69,13 @@ func GetSaltFilterSingleton() *BloomRing {
})
return saltfilter
}

// TestSalt returns true if salt is repeated
func TestSalt(b []byte) bool {
return getSaltFilterSingleton().Test(b)
}

// AddSalt salt to filter
func AddSalt(b []byte) {
getSaltFilterSingleton().Add(b)
}
7 changes: 3 additions & 4 deletions shadowaead/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
if err != nil {
return nil, err
}
internal.GetSaltFilterSingleton().Add(salt)
internal.AddSalt(salt)

if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
return nil, io.ErrShortBuffer
Expand All @@ -45,16 +45,15 @@ func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
if len(pkt) < saltSize {
return nil, ErrShortPacket
}
saltfilter := internal.GetSaltFilterSingleton()
salt := pkt[:saltSize]
if saltfilter.Test(salt) {
if internal.TestSalt(salt) {
return nil, ErrRepeatedSalt
}
aead, err := ciph.Decrypter(salt)
if err != nil {
return nil, err
}
saltfilter.Add(salt)
internal.AddSalt(salt)
if len(pkt) < saltSize+aead.Overhead() {
return nil, ErrShortPacket
}
Expand Down
7 changes: 3 additions & 4 deletions shadowaead/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,14 @@ func (c *streamConn) initReader() error {
if _, err := io.ReadFull(c.Conn, salt); err != nil {
return err
}
saltfilter := internal.GetSaltFilterSingleton()
if saltfilter.Test(salt) {
if internal.TestSalt(salt) {
return ErrRepeatedSalt
}
aead, err := c.Decrypter(salt)
if err != nil {
return err
}
saltfilter.Add(salt)
internal.AddSalt(salt)

c.r = newReader(c.Conn, aead)
return nil
Expand Down Expand Up @@ -250,7 +249,7 @@ func (c *streamConn) initWriter() error {
if err != nil {
return err
}
internal.GetSaltFilterSingleton().Add(salt)
internal.AddSalt(salt)
c.w = newWriter(c.Conn, aead)
return nil
}
Expand Down

0 comments on commit 15e3c6b

Please sign in to comment.