diff --git a/credentials/alts/internal/conn/aes128gcm.go b/credentials/alts/internal/conn/aes128gcm.go index 04e0adb6c908..efd2a9f1feff 100644 --- a/credentials/alts/internal/conn/aes128gcm.go +++ b/credentials/alts/internal/conn/aes128gcm.go @@ -28,7 +28,7 @@ import ( const ( // Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in // each direction). - overflowLenAES128GCM = 5 + overflowLenAES128GCM = 4 ) // aes128gcm is the struct that holds necessary information for ALTS record. diff --git a/credentials/alts/internal/conn/aes128gcmrekey.go b/credentials/alts/internal/conn/aes128gcmrekey.go index 6a9035ea254f..84ec374ebb7e 100644 --- a/credentials/alts/internal/conn/aes128gcmrekey.go +++ b/credentials/alts/internal/conn/aes128gcmrekey.go @@ -27,7 +27,7 @@ import ( const ( // Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in // each direction). - overflowLenAES128GCMRekey = 8 + overflowLenAES128GCMRekey = 4 nonceLen = 12 aeadKeyLen = 16 kdfKeyLen = 32 diff --git a/credentials/alts/internal/conn/counter_test.go b/credentials/alts/internal/conn/counter_test.go index 0e752c3bfdde..2c9fc0f4d813 100644 --- a/credentials/alts/internal/conn/counter_test.go +++ b/credentials/alts/internal/conn/counter_test.go @@ -52,8 +52,9 @@ func (s) TestCounterSides(t *testing.T) { func (s) TestCounterInc(t *testing.T) { for _, test := range []struct { - counter []byte - want []byte + counter []byte + want []byte + expectInvalid bool }{ { counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -72,19 +73,32 @@ func (s) TestCounterInc(t *testing.T) { want: []byte{0x43, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, { - counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + counter: []byte{0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + want: []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { - counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, - want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + counter: []byte{0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + want: []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + }, + { + counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + want: []byte{}, + expectInvalid: true, + }, + { + counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + want: []byte{}, + expectInvalid: true, }, } { c := CounterFromValue(test.counter, overflowLenAES128GCM) c.Inc() value, _ := c.Value() - if g, w := value, test.want; !bytes.Equal(g, w) || c.invalid { - t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, g, w) + if got, want := c.invalid, test.expectInvalid; got != want { + t.Errorf("counter.invalid=%t, want=%t", got, want) + } + if got, want := value, test.want; !bytes.Equal(got, want) { + t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, got, want) } } }