Skip to content

Commit 6897f36

Browse files
committed
make decoder more reliable
The original Java code didn't perform any bounds checks, and thus the original Go translation didn't either. This patch updates the API to return an error from Decompress() and adds bounds checks.
1 parent 1716209 commit 6897f36

File tree

3 files changed

+83
-26
lines changed

3 files changed

+83
-26
lines changed

fuzz.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
package quicklz
44

5-
import "encoding/binary"
6-
75
func Fuzz(data []byte) int {
86

97
if len(data) < 5 {
@@ -16,12 +14,12 @@ func Fuzz(data []byte) int {
1614

1715
}
1816

19-
ln := binary.LittleEndian.Uint32(data[1:])
17+
ln, _ := sizeDecompressed(data)
2018
if ln > (1 << 21) {
2119
return 0
2220
}
2321

24-
if b := Decompress(data); b == nil {
22+
if _, err := Decompress(data); err != nil {
2523
return 0
2624
}
2725

quicklz.go

+76-21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Licensed under the GPL, like the original.
66
*/
77
package quicklz
88

9+
import "errors"
10+
911
const (
1012
// Streaming mode not supported
1113
QLZ_STREAMING_BUFFER = 0
@@ -36,27 +38,30 @@ func headerLen(source []byte) int {
3638
return 3
3739
}
3840

39-
func sizeDecompressed(source []byte) int {
41+
func sizeDecompressed(source []byte) (int, error) {
4042
if headerLen(source) == 9 {
4143
return fastRead(source, 5, 4)
4244
}
4345
return fastRead(source, 2, 1)
4446

4547
}
4648

47-
func sizeCompressed(source []byte) int {
49+
func sizeCompressed(source []byte) (int, error) {
4850
if headerLen(source) == 9 {
4951
return fastRead(source, 1, 4)
5052
}
5153
return fastRead(source, 1, 1)
5254
}
5355

54-
func fastRead(a []byte, i, numbytes int) int {
56+
func fastRead(a []byte, i, numbytes int) (int, error) {
5557
l := 0
58+
if len(a) < i+numbytes {
59+
return 0, ErrCorrupt
60+
}
5661
for j := 0; j < numbytes; j++ {
5762
l |= int(a[i+j]) << (uint(j) * 8)
5863
}
59-
return l
64+
return l, nil
6065
}
6166

6267
func fastWrite(a []byte, i, value, numbytes int) {
@@ -112,7 +117,7 @@ func Compress(source []byte, level int) []byte {
112117
}
113118

114119
if src <= lastMatchStart {
115-
fetch = fastRead(source, src, 3)
120+
fetch, _ = fastRead(source, src, 3)
116121
}
117122

118123
for src <= lastMatchStart {
@@ -180,7 +185,7 @@ func Compress(source []byte, level int) []byte {
180185
}
181186
}
182187
lits = 0
183-
fetch = fastRead(source, src, 3)
188+
fetch, _ = fastRead(source, src, 3)
184189
} else {
185190
lits++
186191
hashCounter[hash] = 1
@@ -191,7 +196,7 @@ func Compress(source []byte, level int) []byte {
191196
fetch = (fetch>>8)&0xffff | int(source[src+2])<<16
192197
}
193198
} else {
194-
fetch = fastRead(source, src, 3)
199+
fetch, _ = fastRead(source, src, 3)
195200

196201
var o, offset2 int
197202
var matchlen, k, m int
@@ -230,7 +235,7 @@ func Compress(source []byte, level int) []byte {
230235
if matchlen >= 3 && src-o < 131071 {
231236
offset := src - o
232237
for u := 1; u < matchlen; u++ {
233-
fetch = fastRead(source, src+u, 3)
238+
fetch, _ = fastRead(source, src+u, 3)
234239
hash = ((fetch >> 12) ^ fetch) & (HASH_VALUES - 1)
235240
c = hashCounter[hash]
236241
hashCounter[hash]++
@@ -289,8 +294,16 @@ func Compress(source []byte, level int) []byte {
289294
return d2
290295
}
291296

292-
func Decompress(source []byte) []byte {
293-
size := sizeDecompressed(source)
297+
var (
298+
ErrCorrupt = errors.New("quicklz: corrupt document")
299+
ErrInvalidVersion = errors.New("quicklz: unsupported compression version")
300+
)
301+
302+
func Decompress(source []byte) ([]byte, error) {
303+
size, err := sizeDecompressed(source)
304+
if err != nil || size < 0 {
305+
return nil, ErrCorrupt
306+
}
294307
src := headerLen(source)
295308
var dst int
296309
var cwordVal = 1
@@ -305,24 +318,35 @@ func Decompress(source []byte) []byte {
305318
level := (source[0] >> 2) & 0x3
306319

307320
if level != 1 && level != 3 {
308-
panic("Go version only supports level 1 and 3")
321+
return nil, ErrInvalidVersion
309322
}
310323

311324
if (source[0] & 1) != 1 {
312325
d2 := make([]byte, size)
313-
copy(d2, source[headerLen(source):])
314-
return d2
326+
l := headerLen(source)
327+
if len(source) < l {
328+
return nil, ErrCorrupt
329+
}
330+
copy(d2, source[l:])
331+
return d2, nil
315332
}
316333

317334
for {
318335
if cwordVal == 1 {
319-
cwordVal = fastRead(source, src, 4)
336+
var err error
337+
cwordVal, err = fastRead(source, src, 4)
338+
if err != nil {
339+
return nil, ErrCorrupt
340+
}
320341
src += 4
321342
if dst <= lastMatchStart {
322343
if level == 1 {
323-
fetch = fastRead(source, src, 3)
344+
fetch, err = fastRead(source, src, 3)
324345
} else {
325-
fetch = fastRead(source, src, 4)
346+
fetch, err = fastRead(source, src, 4)
347+
}
348+
if err != nil {
349+
return nil, ErrCorrupt
326350
}
327351
}
328352
}
@@ -341,6 +365,9 @@ func Decompress(source []byte) []byte {
341365
matchlen = (fetch & 0xf) + 2
342366
src += 2
343367
} else {
368+
if len(source) <= src+2 {
369+
return nil, ErrCorrupt
370+
}
344371
matchlen = int(source[src+2]) & 0xff
345372
src += 3
346373
}
@@ -371,6 +398,10 @@ func Decompress(source []byte) []byte {
371398
offset2 = int(dst - offset)
372399
}
373400

401+
if matchlen < 0 || offset2 < 0 || len(destination) <= dst+2 || len(destination) <= offset2+matchlen || len(destination) <= dst+matchlen {
402+
return nil, ErrCorrupt
403+
}
404+
374405
destination[dst+0] = destination[offset2+0]
375406
destination[dst+1] = destination[offset2+1]
376407
destination[dst+2] = destination[offset2+2]
@@ -381,17 +412,29 @@ func Decompress(source []byte) []byte {
381412
dst += matchlen
382413

383414
if level == 1 {
384-
fetch = fastRead(destination, lastHashed+1, 3) // destination[lastHashed + 1] | (destination[lastHashed + 2] << 8) | (destination[lastHashed + 3] << 16);
415+
fetch, err = fastRead(destination, lastHashed+1, 3) // destination[lastHashed + 1] | (destination[lastHashed + 2] << 8) | (destination[lastHashed + 3] << 16);
416+
if err != nil {
417+
return nil, ErrCorrupt
418+
}
385419
for lastHashed < dst-matchlen {
386420
lastHashed++
387421
hash = ((fetch >> 12) ^ fetch) & (HASH_VALUES - 1)
388422
hashtable[hash] = lastHashed
389423
hashCounter[hash] = 1
424+
if len(destination) <= lastHashed+3 {
425+
return nil, ErrCorrupt
426+
}
390427
fetch = (fetch >> 8 & 0xffff) | (int(destination[lastHashed+3]) << 16)
391428
}
392-
fetch = fastRead(source, src, 3)
429+
fetch, err = fastRead(source, src, 3)
430+
if err != nil {
431+
return nil, ErrCorrupt
432+
}
393433
} else {
394-
fetch = fastRead(source, src, 4)
434+
fetch, err = fastRead(source, src, 4)
435+
if err != nil {
436+
return nil, ErrCorrupt
437+
}
395438
}
396439
lastHashed = dst - 1
397440
} else {
@@ -404,13 +447,22 @@ func Decompress(source []byte) []byte {
404447
if level == 1 {
405448
for lastHashed < dst-3 {
406449
lastHashed++
407-
fetch2 := fastRead(destination, lastHashed, 3)
450+
fetch2, err := fastRead(destination, lastHashed, 3)
451+
if err != nil {
452+
return nil, ErrCorrupt
453+
}
408454
hash = ((fetch2 >> 12) ^ fetch2) & (HASH_VALUES - 1)
409455
hashtable[hash] = lastHashed
410456
hashCounter[hash] = 1
411457
}
458+
if len(source) <= src+2 {
459+
return nil, ErrCorrupt
460+
}
412461
fetch = fetch>>8&0xffff | int(source[src+2])<<16
413462
} else {
463+
if len(source) <= src+3 {
464+
return nil, ErrCorrupt
465+
}
414466
fetch = fetch>>8&0xffff | int(source[src+2])<<16 | int(source[src+3])<<24
415467
}
416468
} else {
@@ -420,12 +472,15 @@ func Decompress(source []byte) []byte {
420472
cwordVal = 0x80000000
421473
}
422474

475+
if len(destination) <= dst || len(source) <= src {
476+
return nil, ErrCorrupt
477+
}
423478
destination[dst] = source[src]
424479
dst++
425480
src++
426481
cwordVal = cwordVal >> 1
427482
}
428-
return destination
483+
return destination, nil
429484
}
430485
}
431486
}

quicklz_test.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ func TestCompress(t *testing.T) {
1515

1616
qz := Compress(in[:i], 1)
1717

18-
out := Decompress(qz)
18+
out, err := Decompress(qz)
19+
if err != nil {
20+
t.Errorf("roundtrip error length %d: %v", i, err)
21+
}
22+
1923
if !bytes.Equal(in[:i], out) {
2024
offs := dump(t, "o", out, "i", in[:i])
2125
t.Log("\n" + hex.Dump(qz))

0 commit comments

Comments
 (0)