Skip to content

Commit 5b26619

Browse files
committed
Add support for in-place decompression
* Add a function and macro ZSTD_decompressionMargin() that computes the decompression margin for in-place decompression. The function computes a tight margin that works in all cases, and the macro computes an upper bound that will only work if flush isn't used. * When doing in-place decompression, make sure that our output buffer doesn't overlap with the input buffer. This ensures that we don't decide to use the portion of the output buffer that overlaps the input buffer for temporary memory, like for literals. * Add a simple unit test. * Add in-place decompression to the simple_round_trip and stream_round_trip fuzzers. This should help verify that our margin stays correct.
1 parent 423500d commit 5b26619

File tree

7 files changed

+227
-10
lines changed

7 files changed

+227
-10
lines changed

lib/common/zstd_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore
341341
* `decompressedBound != ZSTD_CONTENTSIZE_ERROR`
342342
*/
343343
typedef struct {
344+
size_t nbBlocks; /* number of blocks in the frame; the margin computation charges 3 bytes for each */
344345
size_t compressedSize; /* bytes consumed by the frame, or an error code : test with ZSTD_isError() before use */
345346
unsigned long long decompressedBound; /* frame content size when the header provides it; ZSTD_CONTENTSIZE_ERROR on failure */
346347
} ZSTD_frameSizeInfo; /* decompress & legacy */

lib/decompress/zstd_decompress.c

+65-3
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
782782
ip += 4;
783783
}
784784

785+
frameSizeInfo.nbBlocks = nbBlocks;
785786
frameSizeInfo.compressedSize = (size_t)(ip - ipstart);
786787
frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN)
787788
? zfh.frameContentSize
@@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
825826
return bound;
826827
}
827828

829+
size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)
830+
{
831+
size_t margin = 0;
832+
unsigned maxBlockSize = 0;
833+
834+
/* Iterate over each frame */
835+
while (srcSize > 0) {
836+
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
837+
size_t const compressedSize = frameSizeInfo.compressedSize;
838+
unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
839+
ZSTD_frameHeader zfh;
840+
841+
FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), "");
842+
if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
843+
return ERROR(corruption_detected);
844+
845+
if (zfh.frameType == ZSTD_frame) {
846+
/* Add the frame header to our margin */
847+
margin += zfh.headerSize;
848+
/* Add the checksum to our margin */
849+
margin += zfh.checksumFlag ? 4 : 0;
850+
/* Add 3 bytes per block */
851+
margin += 3 * frameSizeInfo.nbBlocks;
852+
853+
/* Compute the max block size */
854+
maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax);
855+
} else {
856+
assert(zfh.frameType == ZSTD_skippableFrame);
857+
/* Add the entire skippable frame size to our margin. */
858+
margin += compressedSize;
859+
}
860+
861+
assert(srcSize >= compressedSize);
862+
src = (const BYTE*)src + compressedSize;
863+
srcSize -= compressedSize;
864+
}
865+
866+
/* Add the max block size back to the margin. */
867+
margin += maxBlockSize;
868+
869+
return margin;
870+
}
828871

829872
/*-*************************************************************
830873
* Frame decoding
@@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
850893
if (srcSize == 0) return 0;
851894
RETURN_ERROR(dstBuffer_null, "");
852895
}
853-
ZSTD_memcpy(dst, src, srcSize);
896+
ZSTD_memmove(dst, src, srcSize);
854897
return srcSize;
855898
}
856899

@@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
928971

929972
/* Loop on each block */
930973
while (1) {
974+
BYTE* oBlockEnd = oend;
931975
size_t decodedSize;
932976
blockProperties_t blockProperties;
933977
size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSrcSize, &blockProperties);
@@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
937981
remainingSrcSize -= ZSTD_blockHeaderSize;
938982
RETURN_ERROR_IF(cBlockSize > remainingSrcSize, srcSize_wrong, "");
939983

984+
if (ip >= op && ip < oBlockEnd) {
985+
/* We are decompressing in-place. Limit the output pointer so that we
986+
* don't overwrite the block that we are currently reading. This will
987+
* fail decompression if the input & output pointers aren't spaced
988+
* far enough apart.
989+
*
990+
* This is important to set, even when the pointers are far enough
991+
* apart, because ZSTD_decompressBlock_internal() can decide to store
992+
* literals in the output buffer, after the block it is decompressing.
993+
* Since we don't want anything to overwrite our input, we have to tell
994+
* ZSTD_decompressBlock_internal to never write past ip.
995+
*
996+
* See ZSTD_allocateLiteralsBuffer() for reference.
997+
*/
998+
oBlockEnd = op + (ip - op);
999+
}
1000+
9401001
switch(blockProperties.blockType)
9411002
{
9421003
case bt_compressed:
943-
decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oend-op), ip, cBlockSize, /* frame */ 1, not_streaming);
1004+
decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming);
9441005
break;
9451006
case bt_raw :
1007+
/* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
9461008
decodedSize = ZSTD_copyRawBlock(op, (size_t)(oend-op), ip, cBlockSize);
9471009
break;
9481010
case bt_rle :
949-
decodedSize = ZSTD_setRleBlock(op, (size_t)(oend-op), *ip, blockProperties.origSize);
1011+
decodedSize = ZSTD_setRleBlock(op, (size_t)(oBlockEnd-op), *ip, blockProperties.origSize);
9501012
break;
9511013
case bt_reserved :
9521014
default:

lib/legacy/zstd_legacy.h

+7
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ MEM_STATIC ZSTD_frameSizeInfo ZSTD_findFrameSizeInfoLegacy(const void *src, size
242242
frameSizeInfo.compressedSize = ERROR(srcSize_wrong);
243243
frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
244244
}
245+
/* In all cases, decompressedBound == nbBlocks * ZSTD_BLOCKSIZE_MAX.
246+
* So we can compute nbBlocks without having to change every function.
247+
*/
248+
if (frameSizeInfo.decompressedBound != ZSTD_CONTENTSIZE_ERROR) {
249+
assert((frameSizeInfo.decompressedBound & (ZSTD_BLOCKSIZE_MAX - 1)) == 0);
250+
frameSizeInfo.nbBlocks = (size_t)(frameSizeInfo.decompressedBound / ZSTD_BLOCKSIZE_MAX);
251+
}
245252
return frameSizeInfo;
246253
}
247254

lib/zstd.h

+45
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,51 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size
14271427
* or an error code (if srcSize is too small) */
14281428
ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize);
14291429

1430+
/*! ZSTD_decompressionMargin() :
1431+
* Zstd supports in-place decompression, where the input and output buffers overlap.
1432+
* In this case, the output buffer must be at least (Margin + Output_Size) bytes large,
1433+
* and the input buffer must be at the end of the output buffer.
1434+
*
1435+
* _______________________ Output Buffer ________________________
1436+
* | |
1437+
* | ____ Input Buffer ____|
1438+
* | | |
1439+
* v v v
1440+
* |---------------------------------------|-----------|----------|
1441+
* ^ ^ ^
1442+
* |___________________ Output_Size ___________________|_ Margin _|
1443+
*
1444+
* NOTE: See also ZSTD_DECOMPRESSION_MARGIN().
1445+
* NOTE: This applies only to single-pass decompression through ZSTD_decompress() or
1446+
* ZSTD_decompressDCtx().
1447+
* NOTE: This function supports multi-frame input.
1448+
*
1449+
* @param src The compressed frame(s)
1450+
* @param srcSize The size of the compressed frame(s)
1451+
* @returns The decompression margin or an error that can be checked with ZSTD_isError().
1452+
*/
1453+
ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize);
1454+
1455+
/*! ZSTD_DECOMPRESSION_MARGIN() :
 * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from
 * the compressed frame, compute it from the original size and the block size.
 * See ZSTD_decompressionMargin() for details.
 *
 * WARNING: This macro does not support multi-frame input, the input must be a single
 * zstd frame. If you need that support use the function, or implement it yourself.
 *
 * NOTE: Both arguments are evaluated more than once, so pass expressions
 * without side effects.
 *
 * @param originalSize The original uncompressed size of the data.
 * @param blockSize    The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX).
 *                     Unless you explicitly set the windowLog smaller than
 *                     ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX.
 */
#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)(                                  \
        ZSTD_FRAMEHEADERSIZE_MAX                                                  /* Frame header */ + \
        4                                                                             /* checksum */ + \
        ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / (blockSize))) /* 3 bytes per block */ + \
        (blockSize)                                                        /* One block of margin */   \
    ))
1474+
14301475
typedef enum {
14311476
ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */
14321477
ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */

tests/fuzz/simple_round_trip.c

+37-7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@
2626
static ZSTD_CCtx *cctx = NULL;
2727
static ZSTD_DCtx *dctx = NULL;
2828

29+
static size_t getDecompressionMargin(void const* compressed, size_t cSize, size_t srcSize, int hasSmallBlocks)
30+
{
31+
size_t margin = ZSTD_decompressionMargin(compressed, cSize);
32+
if (!hasSmallBlocks) {
33+
/* The macro should be correct in this case, but it may be smaller
34+
* because of e.g. block splitting, so take the smaller of the two.
35+
*/
36+
ZSTD_frameHeader zfh;
37+
size_t marginM;
38+
FUZZ_ZASSERT(ZSTD_getFrameHeader(&zfh, compressed, cSize));
39+
marginM = ZSTD_DECOMPRESSION_MARGIN(srcSize, zfh.blockSizeMax);
40+
if (marginM < margin)
41+
margin = marginM;
42+
}
43+
return margin;
44+
}
45+
2946
static size_t roundTripTest(void *result, size_t resultCapacity,
3047
void *compressed, size_t compressedCapacity,
3148
const void *src, size_t srcSize,
@@ -67,6 +84,25 @@ static size_t roundTripTest(void *result, size_t resultCapacity,
6784
}
6885
dSize = ZSTD_decompressDCtx(dctx, result, resultCapacity, compressed, cSize);
6986
FUZZ_ZASSERT(dSize);
87+
FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
88+
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, result, dSize), "Corruption!");
89+
90+
{
91+
size_t margin = getDecompressionMargin(compressed, cSize, srcSize, targetCBlockSize);
92+
size_t const outputSize = srcSize + margin;
93+
char* const output = (char*)FUZZ_malloc(outputSize);
94+
char* const input = output + outputSize - cSize;
95+
FUZZ_ASSERT(outputSize >= cSize);
96+
memcpy(input, compressed, cSize);
97+
98+
dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
99+
FUZZ_ZASSERT(dSize);
100+
FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
101+
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, srcSize), "Corruption!");
102+
103+
free(output);
104+
}
105+
70106
/* When superblock is enabled make sure we don't expand the block more than expected.
71107
* NOTE: This test is currently disabled because superblock mode can arbitrarily
72108
* expand the block in the worst case. Once superblock mode has been improved we can
@@ -120,13 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
120156
FUZZ_ASSERT(dctx);
121157
}
122158

123-
{
124-
size_t const result =
125-
roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer);
126-
FUZZ_ZASSERT(result);
127-
FUZZ_ASSERT_MSG(result == size, "Incorrect regenerated size");
128-
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
129-
}
159+
roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer);
130160
free(rBuf);
131161
free(cBuf);
132162
FUZZ_dataProducer_free(producer);

tests/fuzz/stream_round_trip.c

+18
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,24 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
166166
FUZZ_ZASSERT(rSize);
167167
FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
168168
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
169+
170+
/* Test in-place decompression (note the macro doesn't work in this case) */
171+
{
172+
size_t const margin = ZSTD_decompressionMargin(cBuf, cSize);
173+
size_t const outputSize = size + margin;
174+
char* const output = (char*)FUZZ_malloc(outputSize);
175+
char* const input = output + outputSize - cSize;
176+
size_t dSize;
177+
FUZZ_ASSERT(outputSize >= cSize);
178+
memcpy(input, cBuf, cSize);
179+
180+
dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
181+
FUZZ_ZASSERT(dSize);
182+
FUZZ_ASSERT_MSG(dSize == size, "Incorrect regenerated size");
183+
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, size), "Corruption!");
184+
185+
free(output);
186+
}
169187
}
170188

171189
FUZZ_dataProducer_free(producer);

tests/fuzzer.c

+54
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,60 @@ static int basicUnitTests(U32 const seed, double compressibility)
12201220
}
12211221
DISPLAYLEVEL(3, "OK \n");
12221222

1223+
DISPLAYLEVEL(3, "test%3i : in-place decompression : ", testNb++);
1224+
cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize, -ZSTD_BLOCKSIZE_MAX);
1225+
CHECK_Z(cSize);
1226+
CHECK_LT(CNBuffSize, cSize);
1227+
{
1228+
size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
1229+
size_t const outputSize = (CNBuffSize + margin);
1230+
char* output = malloc(outputSize);
1231+
char* input = output + outputSize - cSize;
1232+
CHECK_LT(cSize, CNBuffSize + margin);
1233+
CHECK(output != NULL);
1234+
CHECK_Z(margin);
1235+
CHECK(margin <= ZSTD_DECOMPRESSION_MARGIN(CNBuffSize, ZSTD_BLOCKSIZE_MAX));
1236+
memcpy(input, compressedBuffer, cSize);
1237+
1238+
{
1239+
size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
1240+
CHECK_Z(dSize);
1241+
CHECK_EQ(dSize, CNBuffSize);
1242+
}
1243+
CHECK(!memcmp(output, CNBuffer, CNBuffSize));
1244+
free(output);
1245+
}
1246+
DISPLAYLEVEL(3, "OK \n");
1247+
1248+
DISPLAYLEVEL(3, "test%3i : in-place decompression with 2 frames : ", testNb++);
1249+
cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
1250+
CHECK_Z(cSize);
1251+
{
1252+
size_t const cSize2 = ZSTD_compress((char*)compressedBuffer + cSize, compressedBufferSize - cSize, (char const*)CNBuffer + (CNBuffSize / 3), CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
1253+
CHECK_Z(cSize2);
1254+
cSize += cSize2;
1255+
}
1256+
{
1257+
size_t const srcSize = (CNBuffSize / 3) * 2;
1258+
size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
1259+
size_t const outputSize = (CNBuffSize + margin);
1260+
char* output = malloc(outputSize);
1261+
char* input = output + outputSize - cSize;
1262+
CHECK_LT(cSize, CNBuffSize + margin);
1263+
CHECK(output != NULL);
1264+
CHECK_Z(margin);
1265+
memcpy(input, compressedBuffer, cSize);
1266+
1267+
{
1268+
size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
1269+
CHECK_Z(dSize);
1270+
CHECK_EQ(dSize, srcSize);
1271+
}
1272+
CHECK(!memcmp(output, CNBuffer, srcSize));
1273+
free(output);
1274+
}
1275+
DISPLAYLEVEL(3, "OK \n");
1276+
12231277
DISPLAYLEVEL(3, "test%3d: superblock uncompressible data, too many nocompress superblocks : ", testNb++);
12241278
{
12251279
ZSTD_CCtx* const cctx = ZSTD_createCCtx();

0 commit comments

Comments
 (0)