Skip to content

Commit

Permalink
Add recursion limit for dynamic code (#358)
Browse files Browse the repository at this point in the history
Prevent stack exhaustion on:

Decoder:

* CopyNext
* Skip
* ReadIntf
* ReadMapStrIntf
* WriteToJSON

Standalone:

* Skip
* ReadMapStrIntfBytes
* ReadIntfBytes
* CopyToJSON
* UnmarshalAsJSON

Limit is set to 100K recursive map/slice operations.
  • Loading branch information
klauspost authored Sep 6, 2024
1 parent bdea0d5 commit b78c5cd
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 30 deletions.
4 changes: 4 additions & 0 deletions msgp/defs.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ const (
last5 = 0x1f
first3 = 0xe0
last7 = 0x7f

// recursionLimit is the limit of recursive calls.
// This limits the call depth of dynamic code, like Skip and interface conversions.
recursionLimit = 100000
)

func isfixint(b byte) bool {
Expand Down
9 changes: 9 additions & 0 deletions msgp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ var (
// contain the contents of the message
ErrShortBytes error = errShort{}

// ErrRecursion is returned when the maximum recursion limit is reached for an operation.
// This should only realistically be seen on adversarial data trying to exhaust the stack.
ErrRecursion error = errRecursion{}

// this error is only returned
// if we reach code that should
// be unreachable
Expand Down Expand Up @@ -134,6 +138,11 @@ func (f errFatal) Resumable() bool { return false }

func (f errFatal) withContext(ctx string) error { f.ctx = addCtx(f.ctx, ctx); return f }

type errRecursion struct{}

func (e errRecursion) Error() string { return "msgp: recursion limit reached" }
func (e errRecursion) Resumable() bool { return false }

// ArrayError is an error returned
// when decoding a fix-sized array
// of the wrong size
Expand Down
14 changes: 14 additions & 0 deletions msgp/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ func rwMap(dst jsWriter, src *Reader) (n int, err error) {
return dst.WriteString("{}")
}

// This is potentially a recursive call.
if done, err := src.recursiveCall(); err != nil {
return 0, err
} else {
defer done()
}

err = dst.WriteByte('{')
if err != nil {
return
Expand Down Expand Up @@ -162,6 +169,13 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
if err != nil {
return
}
// This is potentially a recursive call.
if done, err := src.recursiveCall(); err != nil {
return 0, err
} else {
defer done()
}

var sz uint32
var nn int
sz, err = src.ReadArrayHeader()
Expand Down
52 changes: 29 additions & 23 deletions msgp/json_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"time"
)

var unfuns [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error)
var unfuns [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error)

func init() {
// NOTE(pmh): this is best expressed as a jump table,
// but gc doesn't do that yet. revisit post-go1.5.
unfuns = [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error){
unfuns = [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error){
StrType: rwStringBytes,
BinType: rwBytesBytes,
MapType: rwMapBytes,
Expand Down Expand Up @@ -51,15 +51,15 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) {
dst = bufio.NewWriterSize(w, 512)
}
for len(msg) > 0 && err == nil {
msg, scratch, err = writeNext(dst, msg, scratch)
msg, scratch, err = writeNext(dst, msg, scratch, 0)
}
if !cast && err == nil {
err = dst.(*bufio.Writer).Flush()
}
return msg, err
}

func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if len(msg) < 1 {
return msg, scratch, ErrShortBytes
}
Expand All @@ -76,10 +76,13 @@ func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
t = TimeType
}
}
return unfuns[t](w, msg, scratch)
return unfuns[t](w, msg, scratch, depth)
}

func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwArrayBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if depth >= recursionLimit {
return msg, scratch, ErrRecursion
}
sz, msg, err := ReadArrayHeaderBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -95,7 +98,7 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}
}
msg, scratch, err = writeNext(w, msg, scratch)
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
if err != nil {
return msg, scratch, err
}
Expand All @@ -104,7 +107,10 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}

func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwMapBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if depth >= recursionLimit {
return msg, scratch, ErrRecursion
}
sz, msg, err := ReadMapHeaderBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -120,15 +126,15 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}
}
msg, scratch, err = rwMapKeyBytes(w, msg, scratch)
msg, scratch, err = rwMapKeyBytes(w, msg, scratch, depth)
if err != nil {
return msg, scratch, err
}
err = w.WriteByte(':')
if err != nil {
return msg, scratch, err
}
msg, scratch, err = writeNext(w, msg, scratch)
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
if err != nil {
return msg, scratch, err
}
Expand All @@ -137,17 +143,17 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
msg, scratch, err := rwStringBytes(w, msg, scratch)
func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
msg, scratch, err := rwStringBytes(w, msg, scratch, depth)
if err != nil {
if tperr, ok := err.(TypeError); ok && tperr.Encoded == BinType {
return rwBytesBytes(w, msg, scratch)
return rwBytesBytes(w, msg, scratch, depth)
}
}
return msg, scratch, err
}

func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwStringBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
str, msg, err := ReadStringZC(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -156,7 +162,7 @@ func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, erro
return msg, scratch, err
}

func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwBytesBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
bts, msg, err := ReadBytesZC(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -180,7 +186,7 @@ func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}

func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwNullBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
msg, err := ReadNilBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -189,7 +195,7 @@ func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwBoolBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
b, msg, err := ReadBoolBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -202,7 +208,7 @@ func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwIntBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
i, msg, err := ReadInt64Bytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -212,7 +218,7 @@ func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwUintBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
u, msg, err := ReadUint64Bytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -222,7 +228,7 @@ func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var f float32
var err error
f, msg, err = ReadFloat32Bytes(msg)
Expand All @@ -234,7 +240,7 @@ func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
return msg, scratch, err
}

func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var f float64
var err error
f, msg, err = ReadFloat64Bytes(msg)
Expand All @@ -246,7 +252,7 @@ func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
return msg, scratch, err
}

func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwTimeBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var t time.Time
var err error
t, msg, err = ReadTimeBytes(msg)
Expand All @@ -261,7 +267,7 @@ func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var err error
var et int8
et, err = peekExtension(msg)
Expand Down
43 changes: 40 additions & 3 deletions msgp/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ type Reader struct {
// is stateless; all the
// buffering is done
// within R.
R *fwd.Reader
scratch []byte
R *fwd.Reader
scratch []byte
recursionDepth int
}

// Read implements `io.Reader`
Expand Down Expand Up @@ -190,6 +191,11 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
return n, io.ErrShortWrite
}

if done, err := m.recursiveCall(); err != nil {
return n, err
} else {
defer done()
}
// for maps and slices, read elements
for x := uintptr(0); x < o; x++ {
var n2 int64
Expand All @@ -202,6 +208,18 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
return n, nil
}

// recursiveCall will increment the recursion depth and return an error if it is exceeded.
// If a nil error is returned, done must be called to decrement the counter.
func (m *Reader) recursiveCall() (done func(), err error) {
if m.recursionDepth >= recursionLimit {
return func() {}, ErrRecursion
}
m.recursionDepth++
return func() {
m.recursionDepth--
}, nil
}

// ReadFull implements `io.ReadFull`
func (m *Reader) ReadFull(p []byte) (int, error) {
return m.R.ReadFull(p)
Expand Down Expand Up @@ -332,7 +350,12 @@ func (m *Reader) Skip() error {
return err
}

// for maps and slices, skip elements
// for maps and slices, skip elements with recursive call
if done, err := m.recursiveCall(); err != nil {
return err
} else {
defer done()
}
for x := uintptr(0); x < o; x++ {
err = m.Skip()
if err != nil {
Expand Down Expand Up @@ -1333,6 +1356,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
return

case MapType:
// This can call back here, so treat as recursive call.
if done, err := m.recursiveCall(); err != nil {
return nil, err
} else {
defer done()
}

mp := make(map[string]interface{})
err = m.ReadMapStrIntf(mp)
i = mp
Expand All @@ -1358,6 +1388,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
if err != nil {
return
}

if done, err := m.recursiveCall(); err != nil {
return nil, err
} else {
defer done()
}

out := make([]interface{}, int(sz))
for j := range out {
out[j], err = m.ReadIntf()
Expand Down
Loading

0 comments on commit b78c5cd

Please sign in to comment.