From ef69863f46225d01208893cea643c93ad7a3c916 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 15 Apr 2021 14:13:55 +0700 Subject: [PATCH] x/bank/types: fix AddressFromBalancesStore address length overflow (#9112) addrLen is encoded in a byte, so it's an uint8. The code in AddressFromBalancesStore cast it to int for bound checking, but wrongly uses "addrLen+1", which can be overflow. To fix this, just cast addrLen once and use it in all places. Found by fuzzing added in #9060. Fixes #9111 --- x/bank/types/key.go | 7 +++---- x/bank/types/key_test.go | 31 ++++++++++++++++++------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/x/bank/types/key.go b/x/bank/types/key.go index 0d8ec96a0de3..a50d04a0a78f 100644 --- a/x/bank/types/key.go +++ b/x/bank/types/key.go @@ -44,12 +44,11 @@ func AddressFromBalancesStore(key []byte) (sdk.AccAddress, error) { return nil, ErrInvalidKey } addrLen := key[0] - if len(key[1:]) < int(addrLen) { + bound := int(addrLen) + if len(key)-1 < bound { return nil, ErrInvalidKey } - addr := key[1 : addrLen+1] - - return sdk.AccAddress(addr), nil + return key[1 : bound+1], nil } // CreateAccountBalancesPrefix creates the prefix for an account's balances. diff --git a/x/bank/types/key_test.go b/x/bank/types/key_test.go index c54037a2279e..9a7f457e45bd 100644 --- a/x/bank/types/key_test.go +++ b/x/bank/types/key_test.go @@ -24,29 +24,34 @@ func TestAddressFromBalancesStore(t *testing.T) { require.NoError(t, err) addrLen := len(addr) require.Equal(t, 20, addrLen) - key := cloneAppend(address.MustLengthPrefix(addr), []byte("stake")) - res, err := types.AddressFromBalancesStore(key) - require.NoError(t, err) - require.Equal(t, res, addr) -} -func TestInvalidAddressFromBalancesStore(t *testing.T) { tests := []struct { - name string - key []byte + name string + key []byte + wantErr bool + expectedKey sdk.AccAddress }{ - {"empty", []byte("")}, - {"invalid", []byte("3AA")}, + {"valid", key, false, addr}, + {"#9111", []byte("\xff000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), false, nil}, + {"empty", []byte(""), true, nil}, + {"invalid", []byte("3AA"), true, nil}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, err := types.AddressFromBalancesStore(tc.key) - assert.Error(t, err) - assert.True(t, errors.Is(types.ErrInvalidKey, err)) + addr, err := types.AddressFromBalancesStore(tc.key) + if tc.wantErr { + assert.Error(t, err) + assert.True(t, errors.Is(types.ErrInvalidKey, err)) + } else { + assert.NoError(t, err) + } + if len(tc.expectedKey) > 0 { + assert.Equal(t, tc.expectedKey, addr) + } }) } }