diff --git a/types/address.go b/types/address.go index f2d33019..ea02a72d 100644 --- a/types/address.go +++ b/types/address.go @@ -5,6 +5,7 @@ import ( "crypto/sha512" "encoding/base32" "encoding/base64" + "fmt" ) const ( @@ -64,7 +65,8 @@ func (a *Address) UnmarshalText(text []byte) error { } // DecodeAddress turns a checksum address string into an Address object. It -// checks that the checksum is correct, and returns an error if it's not. +// checks that the checksum is correct and whether the address is canonical, +// and returns an error if it's not. func DecodeAddress(addr string) (a Address, err error) { // Interpret the address as base32 decoded, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(addr) @@ -94,6 +96,13 @@ func DecodeAddress(addr string) (a Address, err error) { // Checksum is good, copy address bytes into output copy(a[:], addressBytes) + + // Check if address is canonical + if a.String() != addr { + err = fmt.Errorf("address %s is non-canonical", addr) + return + } + return a, nil } diff --git a/types/address_test.go b/types/address_test.go index dee6c278..11205a0e 100644 --- a/types/address_test.go +++ b/types/address_test.go @@ -92,3 +92,18 @@ func TestUnmarshalAddress(t *testing.T) { }) } } + +func TestDecodeNonCanonicalAddress(t *testing.T) { + // Canonical addresses must end with one of the following: "AEIMQUY4", + // e.g. "7HJBGRIWI7GDL42SOJNIAZ7LJ7EBEGKGE5S52QZXAWDXOHDKMDFR6AUXDE" + addrs := []string{ + "7HJBGRIWI7GDL42SOJNIAZ7LJ7EBEGKGE5S52QZXAWDXOHDKMDFR6AUXDF", + "7HJBGRIWI7GDL42SOJNIAZ7LJ7EBEGKGE5S52QZXAWDXOHDKMDFR6AUXDG", + "7HJBGRIWI7GDL42SOJNIAZ7LJ7EBEGKGE5S52QZXAWDXOHDKMDFR6AUXDH", + } + for _, addr := range addrs { + _, err := DecodeAddress(addr) + require.Error(t, err) + require.ErrorContains(t, err, fmt.Sprintf("address %s is non-canonical", addr)) + } +}