@@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
782
782
ip += 4 ;
783
783
}
784
784
785
+ frameSizeInfo .nbBlocks = nbBlocks ;
785
786
frameSizeInfo .compressedSize = (size_t )(ip - ipstart );
786
787
frameSizeInfo .decompressedBound = (zfh .frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN )
787
788
? zfh .frameContentSize
@@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
825
826
return bound ;
826
827
}
827
828
829
+ size_t ZSTD_decompressionMargin (void const * src , size_t srcSize )
830
+ {
831
+ size_t margin = 0 ;
832
+ unsigned maxBlockSize = 0 ;
833
+
834
+ /* Iterate over each frame */
835
+ while (srcSize > 0 ) {
836
+ ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo (src , srcSize );
837
+ size_t const compressedSize = frameSizeInfo .compressedSize ;
838
+ unsigned long long const decompressedBound = frameSizeInfo .decompressedBound ;
839
+ ZSTD_frameHeader zfh ;
840
+
841
+ FORWARD_IF_ERROR (ZSTD_getFrameHeader (& zfh , src , srcSize ), "" );
842
+ if (ZSTD_isError (compressedSize ) || decompressedBound == ZSTD_CONTENTSIZE_ERROR )
843
+ return ERROR (corruption_detected );
844
+
845
+ if (zfh .frameType == ZSTD_frame ) {
846
+ /* Add the frame header to our margin */
847
+ margin += zfh .headerSize ;
848
+ /* Add the checksum to our margin */
849
+ margin += zfh .checksumFlag ? 4 : 0 ;
850
+ /* Add 3 bytes per block */
851
+ margin += 3 * frameSizeInfo .nbBlocks ;
852
+
853
+ /* Compute the max block size */
854
+ maxBlockSize = MAX (maxBlockSize , zfh .blockSizeMax );
855
+ } else {
856
+ assert (zfh .frameType == ZSTD_skippableFrame );
857
+ /* Add the entire skippable frame size to our margin. */
858
+ margin += compressedSize ;
859
+ }
860
+
861
+ assert (srcSize >= compressedSize );
862
+ src = (const BYTE * )src + compressedSize ;
863
+ srcSize -= compressedSize ;
864
+ }
865
+
866
+ /* Add the max block size back to the margin. */
867
+ margin += maxBlockSize ;
868
+
869
+ return margin ;
870
+ }
828
871
829
872
/*-*************************************************************
830
873
* Frame decoding
@@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
850
893
if (srcSize == 0 ) return 0 ;
851
894
RETURN_ERROR (dstBuffer_null , "" );
852
895
}
853
- ZSTD_memcpy (dst , src , srcSize );
896
+ ZSTD_memmove (dst , src , srcSize );
854
897
return srcSize ;
855
898
}
856
899
@@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
928
971
929
972
/* Loop on each block */
930
973
while (1 ) {
974
+ BYTE * oBlockEnd = oend ;
931
975
size_t decodedSize ;
932
976
blockProperties_t blockProperties ;
933
977
size_t const cBlockSize = ZSTD_getcBlockSize (ip , remainingSrcSize , & blockProperties );
@@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
937
981
remainingSrcSize -= ZSTD_blockHeaderSize ;
938
982
RETURN_ERROR_IF (cBlockSize > remainingSrcSize , srcSize_wrong , "" );
939
983
984
+ if (ip >= op && ip < oBlockEnd ) {
985
+ /* We are decompressing in-place. Limit the output pointer so that we
986
+ * don't overwrite the block that we are currently reading. This will
987
+ * fail decompression if the input & output pointers aren't spaced
988
+ * far enough apart.
989
+ *
990
+ * This is important to set, even when the pointers are far enough
991
+ * apart, because ZSTD_decompressBlock_internal() can decide to store
992
+ * literals in the output buffer, after the block it is decompressing.
993
+ * Since we don't want anything to overwrite our input, we have to tell
994
+ * ZSTD_decompressBlock_internal to never write past ip.
995
+ *
996
+ * See ZSTD_allocateLiteralsBuffer() for reference.
997
+ */
998
+ oBlockEnd = op + (ip - op );
999
+ }
1000
+
940
1001
switch (blockProperties .blockType )
941
1002
{
942
1003
case bt_compressed :
943
- decodedSize = ZSTD_decompressBlock_internal (dctx , op , (size_t )(oend - op ), ip , cBlockSize , /* frame */ 1 , not_streaming );
1004
+ decodedSize = ZSTD_decompressBlock_internal (dctx , op , (size_t )(oBlockEnd - op ), ip , cBlockSize , /* frame */ 1 , not_streaming );
944
1005
break ;
945
1006
case bt_raw :
1007
+ /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
946
1008
decodedSize = ZSTD_copyRawBlock (op , (size_t )(oend - op ), ip , cBlockSize );
947
1009
break ;
948
1010
case bt_rle :
949
- decodedSize = ZSTD_setRleBlock (op , (size_t )(oend - op ), * ip , blockProperties .origSize );
1011
+ decodedSize = ZSTD_setRleBlock (op , (size_t )(oBlockEnd - op ), * ip , blockProperties .origSize );
950
1012
break ;
951
1013
case bt_reserved :
952
1014
default :
0 commit comments