Skip to content

Commit

Permalink
Merge pull request #1932 from felixhandte/diagnose-corruption-dicts
Browse files Browse the repository at this point in the history
Add Support for Dictionaries in Corruption Diagnosis Tool
  • Loading branch information
felixhandte authored Dec 19, 2019
2 parents 4ae0f68 + fe454c0 commit 07ad866
Showing 1 changed file with 186 additions and 6 deletions.
192 changes: 186 additions & 6 deletions contrib/diagnose_corruption/check_flipped_bits.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
* You may select, at your option, one of the above-listed licenses.
*/

#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "zstd_errors.h"

#include <stdio.h>
#include <stdlib.h>
Expand All @@ -26,20 +28,50 @@ typedef struct {
char *output;
size_t output_size;

const char *dict_file_name;
const char *dict_file_dir_name;
int32_t dict_id;
char *dict;
size_t dict_size;
ZSTD_DDict* ddict;

ZSTD_DCtx* dctx;

int success_count;
int error_counts[ZSTD_error_maxCode];
} stuff_t;

static void free_stuff(stuff_t* stuff) {
free(stuff->input);
free(stuff->output);
ZSTD_freeDDict(stuff->ddict);
free(stuff->dict);
ZSTD_freeDCtx(stuff->dctx);
}

static void usage(void) {
fprintf(stderr, "check_flipped_bits input_filename");
fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n");
fprintf(stderr, "\n");
fprintf(stderr, "Arguments:\n");
fprintf(stderr, " -d file: path to a dictionary file to use.\n");
fprintf(stderr, " -D dir : path to a directory, with files containing dictionaries, of the\n"
" form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n");
exit(1);
}

static void print_summary(stuff_t* stuff) {
int error_code;
fprintf(stderr, "%9d successful decompressions\n", stuff->success_count);
for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) {
int count = stuff->error_counts[error_code];
if (count) {
fprintf(
stderr, "%9d failed decompressions with message: %s\n",
count, ZSTD_getErrorString(error_code));
}
}
}

static char* readFile(const char* filename, size_t* size) {
struct stat statbuf;
int ret;
Expand Down Expand Up @@ -87,10 +119,76 @@ static char* readFile(const char* filename, size_t* size) {
return buf;
}

static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
ZSTD_DDict* ddict;
*buf = readFile(filename, size);
if (*buf == NULL) {
fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
return NULL;
}

ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
if (ddict == NULL) {
fprintf(stderr, "Failed to create ddict.\n");
return NULL;
}
if (dict_id != NULL) {
*dict_id = ZSTD_getDictID_fromDDict(ddict);
}
return ddict;
}

static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
if (stuff->dict_file_dir_name == NULL) {
return NULL;
} else {
size_t dir_name_len = strlen(stuff->dict_file_dir_name);
int dir_needs_separator = 0;
size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */;
char *dict_file_name = malloc(dict_file_name_alloc_size);
ZSTD_DDict* ddict;
int32_t read_dict_id;
if (dict_file_name == NULL) {
fprintf(stderr, "malloc failed.\n");
return 0;
}

if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
dir_needs_separator = 1;
}

snprintf(
dict_file_name,
dict_file_name_alloc_size,
"%s%s%u.zstd-dict",
stuff->dict_file_dir_name,
dir_needs_separator ? "/" : "",
dict_id);

/* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */

ddict = readDict(dict_file_name, buf, size, &read_dict_id);
if (ddict == NULL) {
fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
free(dict_file_name);
return 0;
}
if (read_dict_id != dict_id) {
fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id);
free(dict_file_name);
ZSTD_freeDDict(ddict);
return 0;
}

free(dict_file_name);
return ddict;
}
}

static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
const char* input_filename;

if (argc != 2) {
if (argc < 2) {
usage();
}

Expand All @@ -116,11 +214,59 @@ static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
return 0;
}

stuff->dict_file_name = NULL;
stuff->dict_file_dir_name = NULL;
stuff->dict_id = 0;
stuff->dict = NULL;
stuff->dict_size = 0;
stuff->ddict = NULL;

if (argc > 2) {
if (!strcmp(argv[2], "-d")) {
if (argc > 3) {
stuff->dict_file_name = argv[3];
} else {
usage();
}
} else
if (!strcmp(argv[2], "-D")) {
if (argc > 3) {
stuff->dict_file_dir_name = argv[3];
} else {
usage();
}
} else {
usage();
}
}

if (stuff->dict_file_dir_name) {
int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
if (dict_id != 0) {
stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
if (stuff->ddict == NULL) {
fprintf(stderr, "Failed to create cached ddict.\n");
return 0;
}
stuff->dict_id = dict_id;
}
} else
if (stuff->dict_file_name) {
stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
if (stuff->ddict == NULL) {
fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
return 0;
}
}

stuff->dctx = ZSTD_createDCtx();
if (stuff->dctx == NULL) {
return 0;
}

stuff->success_count = 0;
memset(stuff->error_counts, 0, sizeof(stuff->error_counts));

return 1;
}

Expand All @@ -129,23 +275,53 @@ static int test_decompress(stuff_t* stuff) {
ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
ZSTD_DCtx* dctx = stuff->dctx;
int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
char *custom_dict = NULL;
size_t custom_dict_size = 0;
ZSTD_DDict* custom_ddict = NULL;

if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
/* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
}

ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
ZSTD_DCtx_refDDict(dctx, NULL);

if (custom_ddict != NULL) {
ZSTD_DCtx_refDDict(dctx, custom_ddict);
} else {
ZSTD_DCtx_refDDict(dctx, stuff->ddict);
}

while (in.pos != in.size) {
out.pos = 0;
ret = ZSTD_decompressStream(dctx, &out, &in);

if (ZSTD_isError(ret)) {
unsigned int code = ZSTD_getErrorCode(ret);
if (code >= ZSTD_error_maxCode) {
fprintf(stderr, "Received unexpected error code!\n");
exit(1);
}
stuff->error_counts[code]++;
/*
fprintf(
stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
*/
if (custom_ddict != NULL) {
ZSTD_freeDDict(custom_ddict);
free(custom_dict);
}
return 0;
}
}

stuff->success_count++;

if (custom_ddict != NULL) {
ZSTD_freeDDict(custom_ddict);
free(custom_dict);
}
return 1;
}

Expand All @@ -155,7 +331,7 @@ static int perturb_bits(stuff_t* stuff) {
for (pos = 0; pos < stuff->input_size; pos++) {
unsigned char old_val = stuff->input[pos];
if (pos % 1000 == 0) {
fprintf(stderr, "Perturbing byte %zu\n", pos);
fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
}
for (bit = 0; bit < 8; bit++) {
unsigned char new_val = old_val ^ (1 << bit);
Expand All @@ -179,7 +355,7 @@ static int perturb_bytes(stuff_t* stuff) {
for (pos = 0; pos < stuff->input_size; pos++) {
unsigned char old_val = stuff->input[pos];
if (pos % 1000 == 0) {
fprintf(stderr, "Perturbing byte %zu\n", pos);
fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
}
for (new_val = 0; new_val < 256; new_val++) {
stuff->perturbed[pos] = new_val;
Expand Down Expand Up @@ -213,5 +389,9 @@ int main(int argc, char* argv[]) {

perturb_bytes(&stuff);

print_summary(&stuff);

free_stuff(&stuff);
}

return 0;
}

0 comments on commit 07ad866

Please sign in to comment.