diff --git a/replaydetector/replaydetector.go b/replaydetector/replaydetector.go index 4358d8f..8d4909e 100644 --- a/replaydetector/replaydetector.go +++ b/replaydetector/replaydetector.go @@ -8,7 +8,14 @@ package replaydetector type ReplayDetector interface { // Check returns true if given sequence number is not replayed. // Call accept() to mark the packet is received properly. - Check(seq uint64) (accept func(), ok bool) + // The return value of accept() indicates whether the accepted packet is + // has the latest observed sequence number. + Check(seq uint64) (accept func() bool, ok bool) +} + +// nop is a no-op func that is returned in the case that Check() fails. +func nop() bool { + return false } type slidingWindowDetector struct { @@ -30,30 +37,33 @@ func New(windowSize uint, maxSeq uint64) ReplayDetector { } } -func (d *slidingWindowDetector) Check(seq uint64) (accept func(), ok bool) { +func (d *slidingWindowDetector) Check(seq uint64) (accept func() bool, ok bool) { if seq > d.maxSeq { // Exceeded upper limit. - return func() {}, false + return nop, false } if seq <= d.latestSeq { if d.latestSeq >= uint64(d.windowSize)+seq { - return func() {}, false + return nop, false } if d.mask.Bit(uint(d.latestSeq-seq)) != 0 { // The sequence number is duplicated. - return func() {}, false + return nop, false } } - return func() { + return func() bool { + latest := seq == 0 if seq > d.latestSeq { // Update the head of the window. d.mask.Lsh(uint(seq - d.latestSeq)) d.latestSeq = seq + latest = true } diff := (d.latestSeq - seq) % d.maxSeq d.mask.SetBit(uint(diff)) + return latest }, true } @@ -75,10 +85,10 @@ type wrappedSlidingWindowDetector struct { init bool } -func (d *wrappedSlidingWindowDetector) Check(seq uint64) (accept func(), ok bool) { +func (d *wrappedSlidingWindowDetector) Check(seq uint64) (accept func() bool, ok bool) { if seq > d.maxSeq { // Exceeded upper limit. - return func() {}, false + return nop, false } if !d.init { if seq != 0 { @@ -99,21 +109,24 @@ func (d *wrappedSlidingWindowDetector) Check(seq uint64) (accept func(), ok bool if diff >= int64(d.windowSize) { // Too old. - return func() {}, false + return nop, false } if diff >= 0 { if d.mask.Bit(uint(diff)) != 0 { // The sequence number is duplicated. - return func() {}, false + return nop, false } } - return func() { + return func() bool { + latest := false if diff < 0 { // Update the head of the window. d.mask.Lsh(uint(-diff)) d.latestSeq = seq + latest = true } d.mask.SetBit(uint(d.latestSeq - seq)) + return latest }, true }