Skip to content

Commit

Permalink
feat: wrap parsing errors into ErrInvalidCid
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias committed Mar 20, 2023
1 parent 85c4236 commit b98e249
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 39 deletions.
94 changes: 58 additions & 36 deletions cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,32 @@ import (
// UnsupportedVersionString just holds an error message
const UnsupportedVersionString = "<unsupported cid version>"

// ErrInvalidCid is an error that indicates that a CID is invalid.
type ErrInvalidCid struct {
Err error
}

func (e ErrInvalidCid) Error() string {
return fmt.Sprintf("invalid cid: %s", e.Err)
}

func (e ErrInvalidCid) Unwrap() error {
return e.Err
}

func (e ErrInvalidCid) Is(err error) bool {
switch err.(type) {
case ErrInvalidCid, *ErrInvalidCid:
return true
default:
return false
}
}

var (
// ErrCidTooShort means that the cid passed to decode was not long
// enough to be a valid Cid
ErrCidTooShort = errors.New("cid too short")
ErrCidTooShort = ErrInvalidCid{errors.New("cid too short")}

// ErrInvalidEncoding means that selected encoding is not supported
// by this Cid version
Expand Down Expand Up @@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) {
// incorrectly detect it as CidV1 in the Version() method
dec, err := mh.Decode(mhash)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
if dec.Code != mh.SHA2_256 || dec.Length != 32 {
return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)
return Undef, ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)}
}
return Cid{string(mhash)}, nil
}
Expand Down Expand Up @@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) {
case Cid:
return v2, nil
default:
return Undef, fmt.Errorf("can't parse %+v as Cid", v2)
return Undef, ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)}
}
}

Expand Down Expand Up @@ -210,15 +232,15 @@ func Decode(v string) (Cid, error) {
if len(v) == 46 && v[:2] == "Qm" {
hash, err := mh.FromB58String(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

return tryNewCidV0(hash)
}

_, data, err := mbase.Decode(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

return Cast(data)
Expand All @@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
// check encoding is valid
_, err := mbase.NewEncoder(encoding)
if err != nil {
return -1, err
return -1, ErrInvalidCid{err}
}

return encoding, nil
Expand All @@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
func Cast(data []byte) (Cid, error) {
nr, c, err := CidFromBytes(data)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

if nr != len(data) {
return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast")
return Undef, ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")}
}

return c, nil
Expand Down Expand Up @@ -434,28 +456,28 @@ func (c Cid) Equals(o Cid) bool {
// UnmarshalJSON parses the JSON representation of a Cid.
func (c *Cid) UnmarshalJSON(b []byte) error {
if len(b) < 2 {
return fmt.Errorf("invalid cid json blob")
return ErrInvalidCid{fmt.Errorf("invalid cid json blob")}
}
obj := struct {
CidTarget string `json:"/"`
}{}
objptr := &obj
err := json.Unmarshal(b, &objptr)
if err != nil {
return err
return ErrInvalidCid{err}
}
if objptr == nil {
*c = Cid{}
return nil
}

if obj.CidTarget == "" {
return fmt.Errorf("cid was incorrectly formatted")
return ErrInvalidCid{fmt.Errorf("cid was incorrectly formatted")}
}

out, err := Decode(obj.CidTarget)
if err != nil {
return err
return ErrInvalidCid{err}
}

*c = out
Expand Down Expand Up @@ -542,12 +564,12 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
if p.Version == 0 && (p.MhType != mh.SHA2_256 ||
(p.MhLength != 32 && p.MhLength != -1)) {

return Undef, fmt.Errorf("invalid v0 prefix")
return Undef, ErrInvalidCid{fmt.Errorf("invalid v0 prefix")}
}

hash, err := mh.Sum(data, p.MhType, length)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

switch p.Version {
Expand All @@ -556,7 +578,7 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
case 1:
return NewCidV1(p.Codec, hash), nil
default:
return Undef, fmt.Errorf("invalid cid version")
return Undef, ErrInvalidCid{fmt.Errorf("invalid cid version")}
}
}

Expand Down Expand Up @@ -586,22 +608,22 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
r := bytes.NewReader(buf)
vers, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

codec, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

mhtype, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

mhlen, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

return Prefix{
Expand All @@ -615,34 +637,34 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
func CidFromBytes(data []byte) (int, Cid, error) {
if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 {
if len(data) < 34 {
return 0, Undef, fmt.Errorf("not enough bytes for cid v0")
return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")}
}

h, err := mh.Cast(data[:34])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

return 34, Cid{string(h)}, nil
}

vers, n, err := varint.FromUvarint(data)
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

if vers != 1 {
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return 0, Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}

_, cn, err := varint.FromUvarint(data[n:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

mhnr, _, err := mh.MHFromBytes(data[n+cn:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

l := n + cn + mhnr
Expand Down Expand Up @@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
vers, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// If we have a CIDv0, read the rest of the bytes and cast the buffer.
if vers == mh.SHA2_256 {
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
return len(br.dst) + n, Undef, err
return len(br.dst) + n, Undef, ErrInvalidCid{err}
}

br.dst = br.dst[:34]
h, err := mh.Cast(br.dst)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

return len(br.dst), Cid{string(h)}, nil
}

if vers != 1 {
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}

// CID block encoding multicodec.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// We could replace most of the code below with go-multihash's ReadMultihash.
Expand All @@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// Multihash hash function code.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// Multihash digest length.
mhl, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// Refuse to make large allocations to prevent OOMs due to bugs.
const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc {
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)}
}

// Fine to convert mhl to int, given maxDigestAlloc.
Expand All @@ -772,15 +794,15 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
// We can't use len(br.dst) here,
// as we've only read n bytes past prefixLength.
return prefixLength + n, Undef, err
return prefixLength + n, Undef, ErrInvalidCid{err}
}

// This simply ensures the multihash is valid.
// TODO: consider removing this bit, as it's probably redundant;
// for now, it helps ensure consistency with CidFromBytes.
_, _, err = mh.MHFromBytes(br.dst[mhStart:])
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

return len(br.dst), Cid{string(br.dst)}, nil
Expand Down
Loading

0 comments on commit b98e249

Please sign in to comment.