Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UTXO DoS Vulnerability Fix #191

Merged
merged 3 commits into from
May 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions vms/avm/prefixed_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ func (s *prefixedState) SetDBInitialized(status choices.Status) error {
// Funds returns the mapping from the 32 byte representation of an address to a
// list of utxo IDs that reference the address.
func (s *prefixedState) Funds(id ids.ID) ([]ids.ID, error) {
return s.state.IDs(uniqueID(id, fundsID, s.funds))
}

// SetFunds saves the mapping from address to utxo IDs to storage.
func (s *prefixedState) SetFunds(id ids.ID, idSlice []ids.ID) error {
return s.state.SetIDs(uniqueID(id, fundsID, s.funds), idSlice)
return s.state.IDs(id)
}

// SpendUTXO consumes the provided utxo.
Expand All @@ -106,11 +101,7 @@ func (s *prefixedState) SpendUTXO(utxoID ids.ID) error {
func (s *prefixedState) removeUTXO(addrs [][]byte, utxoID ids.ID) error {
for _, addr := range addrs {
addrID := ids.NewID(hashing.ComputeHash256Array(addr))
utxos := ids.Set{}
funds, _ := s.Funds(addrID)
utxos.Add(funds...)
utxos.Remove(utxoID)
if err := s.SetFunds(addrID, utxos.List()); err != nil {
if err := s.state.RemoveID(addrID, utxoID); err != nil {
return err
}
}
Expand All @@ -135,11 +126,7 @@ func (s *prefixedState) FundUTXO(utxo *ava.UTXO) error {
func (s *prefixedState) addUTXO(addrs [][]byte, utxoID ids.ID) error {
for _, addr := range addrs {
addrID := ids.NewID(hashing.ComputeHash256Array(addr))
utxos := ids.Set{}
funds, _ := s.Funds(addrID)
utxos.Add(funds...)
utxos.Add(utxoID)
if err := s.SetFunds(addrID, utxos.List()); err != nil {
if err := s.state.AddID(addrID, utxoID); err != nil {
return err
}
}
Expand Down
9 changes: 6 additions & 3 deletions vms/avm/prefixed_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,11 @@ func TestPrefixedFundingAddresses(t *testing.T) {
if err := state.SpendUTXO(utxo.InputID()); err != nil {
t.Fatal(err)
}
_, err = state.Funds(ids.NewID(hashing.ComputeHash256Array([]byte{0})))
if err == nil {
t.Fatalf("Should have returned no utxoIDs")
funds, err = state.Funds(ids.NewID(hashing.ComputeHash256Array([]byte{0})))
if err != nil {
t.Fatal(err)
}
if len(funds) != 0 {
t.Fatalf("Should have returned 0 utxoIDs")
}
}
82 changes: 48 additions & 34 deletions vms/avm/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@ func TestStateIDs(t *testing.T) {

state := vm.state.state

id0 := ids.NewID([32]byte{0xff, 0})
id1 := ids.NewID([32]byte{0xff, 0})
id2 := ids.NewID([32]byte{0xff, 0})
id0 := ids.NewID([32]byte{0x00, 0})
id1 := ids.NewID([32]byte{0x01, 0})
id2 := ids.NewID([32]byte{0x02, 0})

if _, err := state.IDs(ids.Empty); err == nil {
t.Fatalf("Should have errored when reading ids")
if _, err := state.IDs(ids.Empty); err != nil {
t.Fatal(err)
}

expected := []ids.ID{id0, id1}
if err := state.SetIDs(ids.Empty, expected); err != nil {
t.Fatal(err)
for _, id := range expected {
if err := state.AddID(ids.Empty, id); err != nil {
t.Fatal(err)
}
}

result, err := state.IDs(ids.Empty)
Expand All @@ -45,18 +47,36 @@ func TestStateIDs(t *testing.T) {
t.Fatalf("Returned the wrong number of ids")
}

ids.SortIDs(result)
for i, resultID := range result {
expectedID := expected[i]
if !expectedID.Equals(resultID) {
t.Fatalf("Wrong ID returned")
}
}

expected = []ids.ID{id1, id2}
if err := state.SetIDs(ids.Empty, expected); err != nil {
for _, id := range expected {
if err := state.RemoveID(ids.Empty, id); err != nil {
t.Fatal(err)
}
}

result, err = state.IDs(ids.Empty)
if err != nil {
t.Fatal(err)
}

if len(result) != 0 {
t.Fatalf("Should have returned 0 IDs")
}

expected = []ids.ID{id1, id2}
for _, id := range expected {
if err := state.AddID(ids.Empty, id); err != nil {
t.Fatal(err)
}
}

result, err = state.IDs(ids.Empty)
if err != nil {
t.Fatal(err)
Expand All @@ -66,6 +86,7 @@ func TestStateIDs(t *testing.T) {
t.Fatalf("Returned the wrong number of ids")
}

ids.SortIDs(result)
for i, resultID := range result {
expectedID := expected[i]
if !expectedID.Equals(resultID) {
Expand All @@ -84,6 +105,7 @@ func TestStateIDs(t *testing.T) {
t.Fatalf("Returned the wrong number of ids")
}

ids.SortIDs(result)
for i, resultID := range result {
expectedID := expected[i]
if !expectedID.Equals(resultID) {
Expand All @@ -95,18 +117,6 @@ func TestStateIDs(t *testing.T) {
t.Fatal(err)
}

result, err = state.IDs(ids.Empty)
if err == nil {
t.Fatalf("Should have errored during cache lookup")
}

state.Cache.Flush()

result, err = state.IDs(ids.Empty)
if err == nil {
t.Fatalf("Should have errored during parsing")
}

statusResult, err := state.Status(ids.Empty)
if err != nil {
t.Fatal(err)
Expand All @@ -115,16 +125,27 @@ func TestStateIDs(t *testing.T) {
t.Fatalf("Should have returned the %s status", choices.Accepted)
}

if err := state.SetIDs(ids.Empty, []ids.ID{ids.ID{}}); err == nil {
t.Fatalf("Should have errored during serialization")
for _, id := range expected {
if err := state.RemoveID(ids.Empty, id); err != nil {
t.Fatal(err)
}
}

if err := state.SetIDs(ids.Empty, []ids.ID{}); err != nil {
result, err = state.IDs(ids.Empty)
if err != nil {
t.Fatal(err)
}

if _, err := state.IDs(ids.Empty); err == nil {
t.Fatalf("Should have errored when reading ids")
if len(result) != 0 {
t.Fatalf("Should have returned 0 IDs")
}

if err := state.AddID(ids.Empty, ids.ID{}); err == nil {
t.Fatalf("Should have errored during serialization")
}

if err := state.RemoveID(ids.Empty, ids.ID{}); err == nil {
t.Fatalf("Should have errored during serialization")
}
}

Expand Down Expand Up @@ -153,14 +174,7 @@ func TestStateStatuses(t *testing.T) {
t.Fatalf("Should have returned the %s status", choices.Accepted)
}

if err := state.SetIDs(ids.Empty, []ids.ID{ids.Empty}); err != nil {
t.Fatal(err)
}
if _, err := state.Status(ids.Empty); err == nil {
t.Fatalf("Should have errored when reading ids")
}

if err := state.SetStatus(ids.Empty, choices.Accepted); err != nil {
if err := state.AddID(ids.Empty, ids.Empty); err != nil {
t.Fatal(err)
}

Expand Down
18 changes: 3 additions & 15 deletions vms/components/ava/prefixed_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *chainState) UTXO(id ids.ID) (*UTXO, error) {
// Funds returns the mapping from the 32 byte representation of an
// address to a list of utxo IDs that reference the address.
func (s *chainState) Funds(id ids.ID) ([]ids.ID, error) {
return s.IDs(UniqueID(id, s.fundsIDPrefix, s.fundsID))
return s.IDs(id)
}

// SpendUTXO consumes the provided platform utxo.
Expand Down Expand Up @@ -97,11 +97,7 @@ func (s *chainState) setStatus(id ids.ID, status choices.Status) error {
func (s *chainState) removeUTXO(addrs [][]byte, utxoID ids.ID) error {
for _, addr := range addrs {
addrID := ids.NewID(hashing.ComputeHash256Array(addr))
utxos := ids.Set{}
funds, _ := s.Funds(addrID)
utxos.Add(funds...)
utxos.Remove(utxoID)
if err := s.setFunds(addrID, utxos.List()); err != nil {
if err := s.RemoveID(addrID, utxoID); err != nil {
return err
}
}
Expand All @@ -111,21 +107,13 @@ func (s *chainState) removeUTXO(addrs [][]byte, utxoID ids.ID) error {
func (s *chainState) addUTXO(addrs [][]byte, utxoID ids.ID) error {
for _, addr := range addrs {
addrID := ids.NewID(hashing.ComputeHash256Array(addr))
utxos := ids.Set{}
funds, _ := s.Funds(addrID)
utxos.Add(funds...)
utxos.Add(utxoID)
if err := s.setFunds(addrID, utxos.List()); err != nil {
if err := s.AddID(addrID, utxoID); err != nil {
return err
}
}
return nil
}

func (s *chainState) setFunds(id ids.ID, idSlice []ids.ID) error {
return s.SetIDs(UniqueID(id, s.fundsIDPrefix, s.fundsID), idSlice)
}

// PrefixedState wraps a state object. By prefixing the state, there will
// be no collisions between different types of objects that have the same hash.
type PrefixedState struct {
Expand Down
50 changes: 24 additions & 26 deletions vms/components/ava/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (

"github.com/ava-labs/gecko/cache"
"github.com/ava-labs/gecko/database"
"github.com/ava-labs/gecko/database/prefixdb"
"github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/choices"
"github.com/ava-labs/gecko/vms/components/codec"
)

var (
errCacheTypeMismatch = errors.New("type returned from cache doesn't match the expected type")
errZeroID = errors.New("database key ID value not initialized")
)

// UniqueID returns a unique identifier
Expand Down Expand Up @@ -116,39 +118,35 @@ func (s *State) SetStatus(id ids.ID, status choices.Status) error {

// IDs returns a slice of IDs from storage
func (s *State) IDs(id ids.ID) ([]ids.ID, error) {
if idsIntf, found := s.Cache.Get(id); found {
if idSlice, ok := idsIntf.([]ids.ID); ok {
return idSlice, nil
idSlice := []ids.ID(nil)
iter := prefixdb.New(id.Bytes(), s.DB).NewIterator()
defer iter.Release()

for iter.Next() {
keyID, err := ids.ToID(iter.Key())
if err != nil {
return nil, err
}
return nil, errCacheTypeMismatch
}

bytes, err := s.DB.Get(id.Bytes())
if err != nil {
return nil, err
idSlice = append(idSlice, keyID)
}

idSlice := []ids.ID(nil)
if err := s.Codec.Unmarshal(bytes, &idSlice); err != nil {
return nil, err
}

s.Cache.Put(id, idSlice)
return idSlice, nil
}

// SetIDs saves a slice of IDs to the database.
func (s *State) SetIDs(id ids.ID, idSlice []ids.ID) error {
if len(idSlice) == 0 {
s.Cache.Evict(id)
return s.DB.Delete(id.Bytes())
// AddID saves an ID to the prefixed database
func (s *State) AddID(id ids.ID, key ids.ID) error {
if key.IsZero() {
return errZeroID
}
db := prefixdb.New(id.Bytes(), s.DB)
return db.Put(key.Bytes(), nil)
}

bytes, err := s.Codec.Marshal(idSlice)
if err != nil {
return err
// RemoveID removes an ID from the prefixed database
func (s *State) RemoveID(id ids.ID, key ids.ID) error {
if key.IsZero() {
return errZeroID
}

s.Cache.Put(id, idSlice)
return s.DB.Put(id.Bytes(), bytes)
db := prefixdb.New(id.Bytes(), s.DB)
return db.Delete(key.Bytes())
}