Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Dictionaries in Corruption Diagnosis Tool #1932

Merged
merged 3 commits into from
Dec 19, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 186 additions & 6 deletions contrib/diagnose_corruption/check_flipped_bits.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
* You may select, at your option, one of the above-listed licenses.
*/

#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "zstd_errors.h"

#include <stdio.h>
#include <stdlib.h>
Expand All @@ -26,20 +28,50 @@ typedef struct {
char *output;
size_t output_size;

const char *dict_file_name;
const char *dict_file_dir_name;
int32_t dict_id;
char *dict;
size_t dict_size;
ZSTD_DDict* ddict;

ZSTD_DCtx* dctx;

int success_count;
int error_counts[ZSTD_error_maxCode];
} stuff_t;

/*
 * Releases every resource owned by `stuff`: the input and output buffers,
 * the dictionary (the DDict and the buffer it was built from), and the
 * decompression context. The stuff_t object itself is not freed.
 */
static void free_stuff(stuff_t* stuff) {
free(stuff->input);
free(stuff->output);
/* The DDict is created with ZSTD_dlm_byRef over stuff->dict (see readDict),
 * so it is released before the buffer it references. */
ZSTD_freeDDict(stuff->ddict);
free(stuff->dict);
ZSTD_freeDCtx(stuff->dctx);
}

/*
 * Prints the command-line synopsis and option descriptions to stderr, then
 * terminates the process with exit status 1. Does not return.
 */
static void usage(void) {
    /* Only the current synopsis is printed; the older one-argument synopsis
     * (which also lacked a trailing newline) is superseded by this line. */
    fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n");
    fprintf(stderr, "\n");
    fprintf(stderr, "Arguments:\n");
    fprintf(stderr, " -d file: path to a dictionary file to use.\n");
    fprintf(stderr, " -D dir : path to a directory, with files containing dictionaries, of the\n"
                    " form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n");
    exit(1);
}

/*
 * Writes a tally of the decompression attempts to stderr: first the number
 * of successes, then one line per error code that was seen at least once,
 * giving its count and the message returned by ZSTD_getErrorString().
 */
static void print_summary(stuff_t* stuff) {
    int code = 0;

    fprintf(stderr, "%9d successful decompressions\n", stuff->success_count);

    while (code < ZSTD_error_maxCode) {
        int const failures = stuff->error_counts[code];
        /* Error codes that never occurred are left out of the report. */
        if (failures != 0) {
            fprintf(
                stderr, "%9d failed decompressions with message: %s\n",
                failures, ZSTD_getErrorString(code));
        }
        code++;
    }
}

static char* readFile(const char* filename, size_t* size) {
struct stat statbuf;
int ret;
Expand Down Expand Up @@ -87,10 +119,76 @@ static char* readFile(const char* filename, size_t* size) {
return buf;
}

/*
 * Loads a dictionary file and builds a ZSTD_DDict from it.
 *
 * filename: path of the dictionary file to read.
 * buf:      out; receives the buffer returned by readFile(). The DDict is
 *           created with ZSTD_dlm_byRef, so on success the caller must keep
 *           this buffer alive for as long as the DDict is in use and free it
 *           afterwards. On failure *buf is set to NULL.
 * size:     out; receives the size reported by readFile(). Reset to 0 when
 *           DDict creation fails.
 * dict_id:  optional out (may be NULL); receives the dictionary ID reported
 *           by ZSTD_getDictID_fromDDict().
 *
 * Returns the new DDict, which the caller releases with ZSTD_freeDDict(), or
 * NULL if the file could not be read or the DDict could not be created.
 */
static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
    ZSTD_DDict* ddict;

    *buf = readFile(filename, size);
    if (*buf == NULL) {
        fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
        return NULL;
    }

    ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
    if (ddict == NULL) {
        fprintf(stderr, "Failed to create ddict.\n");
        /* Callers treat a NULL return as "nothing to clean up", so the file
         * buffer has to be released here or it is leaked. */
        free(*buf);
        *buf = NULL;
        *size = 0;
        return NULL;
    }

    if (dict_id != NULL) {
        *dict_id = (int32_t)ZSTD_getDictID_fromDDict(ddict);
    }
    return ddict;
}

/*
 * Loads the dictionary with the given ID from the configured dictionary
 * directory (stuff->dict_file_dir_name), looking for a file named
 * "<dir>/<dict_id>.zstd-dict".
 *
 * stuff:   supplies dict_file_dir_name; if it is NULL no lookup is attempted.
 * dict_id: the dictionary ID to look for; it is formatted as an unsigned
 *          decimal number in the file name.
 * buf:     out; on success receives the file buffer backing the returned
 *          DDict (the caller frees it after freeing the DDict). Set to NULL
 *          when loading was attempted and failed.
 * size:    out; receives the size of that buffer.
 *
 * Returns the DDict on success. Returns NULL if no directory is configured,
 * on allocation failure, if the file cannot be loaded, or if the ID stored in
 * the loaded dictionary differs from dict_id.
 */
static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
    if (stuff->dict_file_dir_name == NULL) {
        return NULL;
    } else {
        size_t const dir_name_len = strlen(stuff->dict_file_dir_name);
        size_t const dict_file_name_alloc_size =
            dir_name_len + 1 /* '/' */ + 10 /* max 32-bit decimal len */ + strlen(".zstd-dict") + 1 /* '\0' */;
        char *dict_file_name = malloc(dict_file_name_alloc_size);
        int dir_needs_separator = 0;
        ZSTD_DDict* ddict;
        int32_t read_dict_id = 0;

        if (dict_file_name == NULL) {
            fprintf(stderr, "malloc failed.\n");
            return NULL;
        }

        /* Only add a '/' when the directory name does not already end in one. */
        if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
            dir_needs_separator = 1;
        }

        /* The ID is printed with %u, so pass it as unsigned to match. */
        snprintf(
            dict_file_name,
            dict_file_name_alloc_size,
            "%s%s%u.zstd-dict",
            stuff->dict_file_dir_name,
            dir_needs_separator ? "/" : "",
            (unsigned)dict_id);

        ddict = readDict(dict_file_name, buf, size, &read_dict_id);
        if (ddict == NULL) {
            fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
            free(dict_file_name);
            /* Do not hand a buffer back alongside a NULL DDict: callers only
             * free the buffer when they also received a DDict. */
            free(*buf);
            *buf = NULL;
            return NULL;
        }

        if (read_dict_id != dict_id) {
            fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n",
                    (unsigned)read_dict_id, (unsigned)dict_id);
            free(dict_file_name);
            /* Free the DDict first; it references *buf. */
            ZSTD_freeDDict(ddict);
            free(*buf);
            *buf = NULL;
            *size = 0;
            return NULL;
        }

        free(dict_file_name);
        return ddict;
    }
}

static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
const char* input_filename;

if (argc != 2) {
if (argc < 2) {
usage();
}

Expand All @@ -116,11 +214,59 @@ static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
return 0;
}

stuff->dict_file_name = NULL;
stuff->dict_file_dir_name = NULL;
stuff->dict_id = 0;
stuff->dict = NULL;
stuff->dict_size = 0;
stuff->ddict = NULL;

if (argc > 2) {
if (!strcmp(argv[2], "-d")) {
if (argc > 3) {
stuff->dict_file_name = argv[3];
} else {
usage();
}
} else
if (!strcmp(argv[2], "-D")) {
if (argc > 3) {
stuff->dict_file_dir_name = argv[3];
} else {
usage();
}
} else {
usage();
}
}

if (stuff->dict_file_dir_name) {
int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
if (dict_id != 0) {
stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
if (stuff->ddict == NULL) {
fprintf(stderr, "Failed to create cached ddict.\n");
return 0;
}
stuff->dict_id = dict_id;
}
} else
if (stuff->dict_file_name) {
stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
if (stuff->ddict == NULL) {
fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
return 0;
}
}

stuff->dctx = ZSTD_createDCtx();
if (stuff->dctx == NULL) {
return 0;
}

stuff->success_count = 0;
memset(stuff->error_counts, 0, sizeof(stuff->error_counts));

return 1;
}

Expand All @@ -129,23 +275,53 @@ static int test_decompress(stuff_t* stuff) {
ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
ZSTD_DCtx* dctx = stuff->dctx;
int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
char *custom_dict = NULL;
size_t custom_dict_size = 0;
ZSTD_DDict* custom_ddict = NULL;

if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
/* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
}

ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
ZSTD_DCtx_refDDict(dctx, NULL);

if (custom_ddict != NULL) {
ZSTD_DCtx_refDDict(dctx, custom_ddict);
} else {
ZSTD_DCtx_refDDict(dctx, stuff->ddict);
}

while (in.pos != in.size) {
out.pos = 0;
ret = ZSTD_decompressStream(dctx, &out, &in);

if (ZSTD_isError(ret)) {
unsigned int code = ZSTD_getErrorCode(ret);
if (code >= ZSTD_error_maxCode) {
fprintf(stderr, "Received unexpected error code!\n");
exit(1);
}
stuff->error_counts[code]++;
/*
fprintf(
stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
*/
if (custom_ddict != NULL) {
ZSTD_freeDDict(custom_ddict);
free(custom_dict);
}
return 0;
}
}

stuff->success_count++;

if (custom_ddict != NULL) {
ZSTD_freeDDict(custom_ddict);
free(custom_dict);
}
return 1;
}

Expand All @@ -155,7 +331,7 @@ static int perturb_bits(stuff_t* stuff) {
for (pos = 0; pos < stuff->input_size; pos++) {
unsigned char old_val = stuff->input[pos];
if (pos % 1000 == 0) {
fprintf(stderr, "Perturbing byte %zu\n", pos);
fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
}
for (bit = 0; bit < 8; bit++) {
unsigned char new_val = old_val ^ (1 << bit);
Expand All @@ -179,7 +355,7 @@ static int perturb_bytes(stuff_t* stuff) {
for (pos = 0; pos < stuff->input_size; pos++) {
unsigned char old_val = stuff->input[pos];
if (pos % 1000 == 0) {
fprintf(stderr, "Perturbing byte %zu\n", pos);
fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
}
for (new_val = 0; new_val < 256; new_val++) {
stuff->perturbed[pos] = new_val;
Expand Down Expand Up @@ -213,5 +389,9 @@ int main(int argc, char* argv[]) {

perturb_bytes(&stuff);

print_summary(&stuff);

free_stuff(&stuff);
}

return 0;
}