diff --git a/include/aws/s3/private/s3_checksum_context.h b/include/aws/s3/private/s3_checksum_context.h index ce86aa977..f58b66261 100644 --- a/include/aws/s3/private/s3_checksum_context.h +++ b/include/aws/s3/private/s3_checksum_context.h @@ -8,6 +8,7 @@ #include "aws/s3/s3_client.h" #include +#include #include struct aws_s3_meta_request_checksum_config_storage; @@ -28,9 +29,14 @@ struct aws_s3_upload_request_checksum_context { enum aws_s3_checksum_algorithm algorithm; enum aws_s3_checksum_location location; - struct aws_byte_buf base64_checksum; - /* The checksum already be calculated or not. */ - bool checksum_calculated; + struct { + /* Note: don't directly access the synced_data. */ + /* Lock to make sure the checksum context is safe to be access from different threads. */ + struct aws_mutex lock; + struct aws_byte_buf base64_checksum; + /* The checksum already be calculated or not. */ + bool checksum_calculated; + } synced_data; /* Validation */ size_t encoded_checksum_size; @@ -96,8 +102,7 @@ struct aws_s3_upload_request_checksum_context *aws_s3_upload_request_checksum_co * @return true if checksum calculation is needed, false otherwise */ AWS_S3_API -bool aws_s3_upload_request_checksum_context_should_calculate( - const struct aws_s3_upload_request_checksum_context *context); +bool aws_s3_upload_request_checksum_context_should_calculate(struct aws_s3_upload_request_checksum_context *context); /** * Check if checksum should be added to HTTP headers. @@ -122,15 +127,18 @@ bool aws_s3_upload_request_checksum_context_should_add_trailer( const struct aws_s3_upload_request_checksum_context *context); /** - * Get the checksum buffer to use for output. - * Returns the internal buffer for storing the calculated checksum. + * Encode the checksum to base64 and store it in the context. + * This function is thread-safe and can be called from multiple threads. 
+ * Returns AWS_OP_SUCCESS on success, AWS_OP_ERR otherwise * * @param context The checksum context - * @return Pointer to the checksum buffer, or NULL if context is invalid + * @param raw_checksum_cursor the byte cursor to the raw checksum value. + * @return AWS_OP_SUCCESS on success, AWS_OP_ERR otherwise */ AWS_S3_API -struct aws_byte_buf *aws_s3_upload_request_checksum_context_get_output_buffer( - struct aws_s3_upload_request_checksum_context *context); +int aws_s3_upload_request_checksum_context_finalize_checksum( + struct aws_s3_upload_request_checksum_context *context, + struct aws_byte_cursor raw_checksum_cursor); /** * Get a cursor to the current base64 encoded checksum value (for use in headers/trailers). @@ -141,7 +149,7 @@ struct aws_byte_buf *aws_s3_upload_request_checksum_context_get_output_buffer( */ AWS_S3_API struct aws_byte_cursor aws_s3_upload_request_checksum_context_get_checksum_cursor( - const struct aws_s3_upload_request_checksum_context *context); + struct aws_s3_upload_request_checksum_context *context); AWS_EXTERN_C_END diff --git a/include/aws/s3/private/s3_checksums.h b/include/aws/s3/private/s3_checksums.h index 6d922c84e..d16babb98 100644 --- a/include/aws/s3/private/s3_checksums.h +++ b/include/aws/s3/private/s3_checksums.h @@ -62,8 +62,7 @@ struct aws_s3_meta_request_checksum_config_storage { }; /** - * a stream that takes in a stream, computes a running checksum as it is read, and outputs the checksum when the stream - * is destroyed. + * a stream that takes in a stream * Note: seek this stream will immediately fail, as it would prevent an accurate calculation of the * checksum. * @@ -72,15 +71,28 @@ struct aws_s3_meta_request_checksum_config_storage { * outputs the checksum of existing stream to checksum_output upon destruction. Will be kept * alive by the checksum stream * @param algorithm Checksum algorithm to use. 
- * @param checksum_output Checksum of the `existing_stream`, owned by caller, which will be calculated when this stream - * is destroyed. */ AWS_S3_API struct aws_input_stream *aws_checksum_stream_new( struct aws_allocator *allocator, struct aws_input_stream *existing_stream, - enum aws_s3_checksum_algorithm algorithm, - struct aws_byte_buf *checksum_output); + enum aws_s3_checksum_algorithm algorithm); + +/** + * Finalize the checksum has read so far to the output checksum buf with base64 encoding. + * Not thread safe. + */ +AWS_S3_API +int aws_checksum_stream_finalize_checksum(struct aws_input_stream *checksum_stream, struct aws_byte_buf *checksum_buf); + +/** + * Finalize the checksum has read so far to the checksum context. + * Not thread safe. + */ +AWS_S3_API +int aws_checksum_stream_finalize_checksum_context( + struct aws_input_stream *checksum_stream, + struct aws_s3_upload_request_checksum_context *checksum_context); /** * TODO: properly support chunked encoding. diff --git a/include/aws/s3/private/s3_client_impl.h b/include/aws/s3/private/s3_client_impl.h index c3520a7c8..974a19d01 100644 --- a/include/aws/s3/private/s3_client_impl.h +++ b/include/aws/s3/private/s3_client_impl.h @@ -175,7 +175,7 @@ struct aws_s3_client_vtable { struct aws_http_connection *client_connection, const struct aws_http_make_request_options *options); - void (*after_prepare_upload_part_finish)(struct aws_s3_request *request); + void (*after_prepare_upload_part_finish)(struct aws_s3_request *request, struct aws_http_message *message); }; struct aws_s3_upload_part_timeout_stats { diff --git a/include/aws/s3/s3_client.h b/include/aws/s3/s3_client.h index 9a8eb019e..d92df125f 100644 --- a/include/aws/s3/s3_client.h +++ b/include/aws/s3/s3_client.h @@ -884,7 +884,7 @@ struct aws_s3_meta_request_options { * Optional. * Callback for reviewing an upload before it completes. 
* WARNING: experimental/unstable - * See `aws_s3_upload_review_fn` + * See `aws_s3_meta_request_upload_review_fn` */ aws_s3_meta_request_upload_review_fn *upload_review_callback; diff --git a/source/s3_auto_ranged_put.c b/source/s3_auto_ranged_put.c index a6fd672cd..0c11315fe 100644 --- a/source/s3_auto_ranged_put.c +++ b/source/s3_auto_ranged_put.c @@ -1099,15 +1099,20 @@ static void s_s3_prepare_upload_part_on_read_done(void *user_data) { goto on_done; } struct aws_s3_upload_request_checksum_context *context = previously_uploaded_info->checksum_context; - /* if previously uploaded part had a checksum, compare it to what we just skipped */ - if (context != NULL && context->checksum_calculated == true && - s_verify_part_matches_checksum( - meta_request->allocator, - aws_byte_cursor_from_buf(&request->request_body), - meta_request->checksum_config.checksum_algorithm, - aws_byte_cursor_from_buf(&context->base64_checksum))) { - error_code = aws_last_error_or_unknown(); - goto on_done; + if (context) { + if (!aws_s3_upload_request_checksum_context_should_calculate(context)) { + struct aws_byte_cursor previous_calculated_checksum = + aws_s3_upload_request_checksum_context_get_checksum_cursor(context); + /* if previously uploaded part had a checksum, compare it to what we just skipped */ + if (s_verify_part_matches_checksum( + meta_request->allocator, + aws_byte_cursor_from_buf(&request->request_body), + meta_request->checksum_config.checksum_algorithm, + previous_calculated_checksum) != AWS_OP_SUCCESS) { + error_code = aws_last_error_or_unknown(); + goto on_done; + } + } } } @@ -1157,9 +1162,6 @@ static void s_s3_prepare_upload_part_finish(struct aws_s3_prepare_upload_part_jo checksum_context = part->checksum_context; /* If checksum already calculated, it means either the part being retried or the part resumed from list * parts. 
Keep reusing the old checksum in case of the request body in memory mangled */ - AWS_ASSERT( - !checksum_context->checksum_calculated || request->num_times_prepared > 0 || - auto_ranged_put->resume_token != NULL); aws_s3_meta_request_unlock_synced_data(meta_request); } /* END CRITICAL SECTION */ @@ -1184,15 +1186,15 @@ static void s_s3_prepare_upload_part_finish(struct aws_s3_prepare_upload_part_jo aws_future_http_message_set_error(part_prep->on_complete, aws_last_error()); goto on_done; } + if (client->vtable->after_prepare_upload_part_finish) { + /* TEST ONLY, allow test to stub here. */ + client->vtable->after_prepare_upload_part_finish(request, message); + } /* Success! */ aws_future_http_message_set_result_by_move(part_prep->on_complete, &message); on_done: - if (client->vtable->after_prepare_upload_part_finish) { - /* TEST ONLY, allow test to stub here. */ - client->vtable->after_prepare_upload_part_finish(request); - } AWS_FATAL_ASSERT(aws_future_http_message_is_done(part_prep->on_complete)); aws_future_bool_release(part_prep->asyncstep_read_part); aws_future_http_message_release(part_prep->on_complete); diff --git a/source/s3_checksum_context.c b/source/s3_checksum_context.c index 189a2b128..c2de802ae 100644 --- a/source/s3_checksum_context.c +++ b/source/s3_checksum_context.c @@ -8,9 +8,18 @@ #include #include +static void s_lock_synced_data(struct aws_s3_upload_request_checksum_context *context) { + aws_mutex_lock(&context->synced_data.lock); +} + +static void s_unlock_synced_data(struct aws_s3_upload_request_checksum_context *context) { + aws_mutex_unlock(&context->synced_data.lock); +} + static void s_aws_s3_upload_request_checksum_context_destroy(void *context) { struct aws_s3_upload_request_checksum_context *checksum_context = context; - aws_byte_buf_clean_up(&checksum_context->base64_checksum); + aws_byte_buf_clean_up(&checksum_context->synced_data.base64_checksum); + aws_mutex_clean_up(&checksum_context->synced_data.lock); 
aws_mem_release(checksum_context->allocator, checksum_context); } @@ -24,6 +33,10 @@ static struct aws_s3_upload_request_checksum_context *s_s3_upload_request_checks aws_ref_count_init(&context->ref_count, context, s_aws_s3_upload_request_checksum_context_destroy); context->allocator = allocator; + if (aws_mutex_init(&context->synced_data.lock)) { + aws_s3_upload_request_checksum_context_release(context); + return NULL; + } /* Handle case where no checksum config is provided */ if (!checksum_config || checksum_config->checksum_algorithm == AWS_SCA_NONE) { context->algorithm = AWS_SCA_NONE; @@ -54,7 +67,7 @@ struct aws_s3_upload_request_checksum_context *aws_s3_upload_request_checksum_co s_s3_upload_request_checksum_context_new_base(allocator, checksum_config); if (context && context->encoded_checksum_size > 0) { /* Initial the buffer for checksum */ - aws_byte_buf_init(&context->base64_checksum, allocator, context->encoded_checksum_size); + aws_byte_buf_init(&context->synced_data.base64_checksum, allocator, context->encoded_checksum_size); } return context; } @@ -79,8 +92,8 @@ struct aws_s3_upload_request_checksum_context *aws_s3_upload_request_checksum_co aws_s3_upload_request_checksum_context_release(context); return NULL; } - aws_byte_buf_init_copy_from_cursor(&context->base64_checksum, allocator, existing_base64_checksum); - context->checksum_calculated = true; + aws_byte_buf_init_copy_from_cursor(&context->synced_data.base64_checksum, allocator, existing_base64_checksum); + context->synced_data.checksum_calculated = true; } return context; } @@ -101,14 +114,18 @@ struct aws_s3_upload_request_checksum_context *aws_s3_upload_request_checksum_co return NULL; } -bool aws_s3_upload_request_checksum_context_should_calculate( - const struct aws_s3_upload_request_checksum_context *context) { +bool aws_s3_upload_request_checksum_context_should_calculate(struct aws_s3_upload_request_checksum_context *context) { if (!context || context->algorithm == AWS_SCA_NONE) { 
return false; } + bool should_calculate = false; + s_lock_synced_data(context); /* If not previous calculated */ - return !context->checksum_calculated; + should_calculate = !context->synced_data.checksum_calculated; + s_unlock_synced_data(context); + + return should_calculate; } bool aws_s3_upload_request_checksum_context_should_add_header( @@ -129,19 +146,46 @@ bool aws_s3_upload_request_checksum_context_should_add_trailer( return context->location == AWS_SCL_TRAILER && context->algorithm != AWS_SCA_NONE; } -struct aws_byte_buf *aws_s3_upload_request_checksum_context_get_output_buffer( - struct aws_s3_upload_request_checksum_context *context) { +int aws_s3_upload_request_checksum_context_finalize_checksum( + struct aws_s3_upload_request_checksum_context *context, + struct aws_byte_cursor raw_checksum_cursor) { if (!context) { - return NULL; + return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); } - return &context->base64_checksum; + s_lock_synced_data(context); + /* If not previous calculated */ + if (!context->synced_data.checksum_calculated) { + AWS_ASSERT(context->synced_data.base64_checksum.len == 0); + + if (aws_base64_encode(&raw_checksum_cursor, &context->synced_data.base64_checksum)) { + aws_byte_buf_reset(&context->synced_data.base64_checksum, false); + AWS_LOGF_ERROR( + AWS_LS_S3_GENERAL, + "Failed to base64 encode for the checksum. Raw checksum length: %zu. 
Output buffer capacity: %zu " + "length %zu", + raw_checksum_cursor.len, + context->synced_data.base64_checksum.capacity, + context->synced_data.base64_checksum.len); + s_unlock_synced_data(context); + return AWS_OP_ERR; + } + context->synced_data.checksum_calculated = true; + } + s_unlock_synced_data(context); + return AWS_OP_SUCCESS; } struct aws_byte_cursor aws_s3_upload_request_checksum_context_get_checksum_cursor( - const struct aws_s3_upload_request_checksum_context *context) { + struct aws_s3_upload_request_checksum_context *context) { struct aws_byte_cursor checksum_cursor = {0}; - if (!context || !context->checksum_calculated) { + if (!context) { return checksum_cursor; } - return aws_byte_cursor_from_buf(&context->base64_checksum); + s_lock_synced_data(context); + /* If not previous calculated */ + if (context->synced_data.checksum_calculated) { + checksum_cursor = aws_byte_cursor_from_buf(&context->synced_data.base64_checksum); + } + s_unlock_synced_data(context); + return checksum_cursor; } diff --git a/source/s3_checksum_stream.c b/source/s3_checksum_stream.c index 83b8dc859..00660284a 100644 --- a/source/s3_checksum_stream.c +++ b/source/s3_checksum_stream.c @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. 
*/ +#include "aws/s3/private/s3_checksum_context.h" #include "aws/s3/private/s3_checksums.h" #include #include @@ -14,44 +15,8 @@ struct aws_checksum_stream { struct aws_input_stream *old_stream; struct aws_s3_checksum *checksum; struct aws_byte_buf checksum_result; - /* base64 encoded checksum of the stream, updated at end of stream */ - struct aws_byte_buf *encoded_checksum_output; - bool checksum_finalized; }; -static int s_finalize_checksum(struct aws_checksum_stream *impl) { - if (impl->checksum_finalized) { - return AWS_OP_SUCCESS; - } - - if (aws_checksum_finalize(impl->checksum, &impl->checksum_result) != AWS_OP_SUCCESS) { - AWS_LOGF_ERROR( - AWS_LS_S3_CLIENT, - "Failed to calculate checksum with error code %d (%s).", - aws_last_error(), - aws_error_str(aws_last_error())); - aws_byte_buf_reset(&impl->checksum_result, true); - impl->checksum_finalized = true; - return aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); - } - struct aws_byte_cursor checksum_result_cursor = aws_byte_cursor_from_buf(&impl->checksum_result); - if (aws_base64_encode(&checksum_result_cursor, impl->encoded_checksum_output) != AWS_OP_SUCCESS) { - AWS_LOGF_ERROR( - AWS_LS_S3_CLIENT, - "Failed to base64 encode checksum with error code %d (%s). 
Output capacity: %zu length %zu", - aws_last_error(), - aws_error_str(aws_last_error()), - impl->encoded_checksum_output->capacity, - impl->encoded_checksum_output->len); - aws_byte_buf_reset(&impl->checksum_result, true); - impl->checksum_finalized = true; - return aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); - } - - impl->checksum_finalized = true; - return AWS_OP_SUCCESS; -} - static int s_aws_input_checksum_stream_seek( struct aws_input_stream *stream, int64_t offset, @@ -82,14 +47,6 @@ static int s_aws_input_checksum_stream_read(struct aws_input_stream *stream, str if (aws_checksum_update(impl->checksum, &to_sum)) { return AWS_OP_ERR; } - /* If we're at the end of the stream, compute and store the final checksum */ - struct aws_stream_status status; - if (aws_input_stream_get_status(impl->old_stream, &status)) { - return AWS_OP_ERR; - } - if (status.is_end_of_stream) { - return s_finalize_checksum(impl); - } return AWS_OP_SUCCESS; } @@ -111,10 +68,6 @@ static void s_aws_input_checksum_stream_destroy(struct aws_checksum_stream *impl if (!impl) { return; } - - /* Compute the checksum of whatever was read, if we didn't reach the end of the underlying stream. 
*/ - s_finalize_checksum(impl); - aws_checksum_destroy(impl->checksum); aws_input_stream_release(impl->old_stream); aws_byte_buf_clean_up(&impl->checksum_result); @@ -131,11 +84,8 @@ static struct aws_input_stream_vtable s_aws_input_checksum_stream_vtable = { struct aws_input_stream *aws_checksum_stream_new( struct aws_allocator *allocator, struct aws_input_stream *existing_stream, - enum aws_s3_checksum_algorithm algorithm, - struct aws_byte_buf *checksum_output) { + enum aws_s3_checksum_algorithm algorithm) { AWS_PRECONDITION(existing_stream); - AWS_PRECONDITION(checksum_output); - AWS_PRECONDITION(checksum_output->len == 0 && "Checksum output buffer is not empty"); struct aws_checksum_stream *impl = aws_mem_calloc(allocator, 1, sizeof(struct aws_checksum_stream)); impl->allocator = allocator; @@ -147,7 +97,6 @@ struct aws_input_stream *aws_checksum_stream_new( } aws_byte_buf_init(&impl->checksum_result, allocator, impl->checksum->digest_size); impl->old_stream = aws_input_stream_acquire(existing_stream); - impl->encoded_checksum_output = checksum_output; aws_ref_count_init( &impl->base.ref_count, impl, (aws_simple_completion_callback *)s_aws_input_checksum_stream_destroy); @@ -156,3 +105,62 @@ struct aws_input_stream *aws_checksum_stream_new( aws_mem_release(impl->allocator, impl); return NULL; } + +int aws_checksum_stream_finalize_checksum(struct aws_input_stream *checksum_stream, struct aws_byte_buf *checksum_buf) { + AWS_PRECONDITION(checksum_buf); + AWS_PRECONDITION(checksum_buf->len == 0 && "Checksum output buffer is not empty"); + + struct aws_checksum_stream *impl = AWS_CONTAINER_OF(checksum_stream, struct aws_checksum_stream, base); + + if (aws_checksum_finalize(impl->checksum, &impl->checksum_result) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR( + AWS_LS_S3_CLIENT, + "Failed to calculate checksum with error code %d (%s).", + aws_last_error(), + aws_error_str(aws_last_error())); + aws_byte_buf_reset(&impl->checksum_result, true); + return 
aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); + } + struct aws_byte_cursor checksum_result_cursor = aws_byte_cursor_from_buf(&impl->checksum_result); + if (aws_base64_encode(&checksum_result_cursor, checksum_buf) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR( + AWS_LS_S3_CLIENT, + "Failed to base64 encode checksum with error code %d (%s). Output capacity: %zu length %zu", + aws_last_error(), + aws_error_str(aws_last_error()), + checksum_buf->capacity, + checksum_buf->len); + aws_byte_buf_reset(&impl->checksum_result, true); + return aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); + } + + return AWS_OP_SUCCESS; +} + +int aws_checksum_stream_finalize_checksum_context( + struct aws_input_stream *checksum_stream, + struct aws_s3_upload_request_checksum_context *checksum_context) { + struct aws_checksum_stream *impl = AWS_CONTAINER_OF(checksum_stream, struct aws_checksum_stream, base); + + if (aws_checksum_finalize(impl->checksum, &impl->checksum_result) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR( + AWS_LS_S3_CLIENT, + "Failed to calculate checksum with error code %d (%s).", + aws_last_error(), + aws_error_str(aws_last_error())); + aws_byte_buf_reset(&impl->checksum_result, true); + return aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); + } + struct aws_byte_cursor checksum_result_cursor = aws_byte_cursor_from_buf(&impl->checksum_result); + if (aws_s3_upload_request_checksum_context_finalize_checksum(checksum_context, checksum_result_cursor) != + AWS_OP_SUCCESS) { + AWS_LOGF_ERROR( + AWS_LS_S3_CLIENT, + "Failed to finalize checksum context with error code %d (%s).", + aws_last_error(), + aws_error_str(aws_last_error())); + aws_byte_buf_reset(&impl->checksum_result, true); + return aws_raise_error(AWS_ERROR_S3_CHECKSUM_CALCULATION_FAILED); + } + return AWS_OP_SUCCESS; +} diff --git a/source/s3_chunk_stream.c b/source/s3_chunk_stream.c index 0b00d532b..5b146c784 100644 --- a/source/s3_chunk_stream.c +++ b/source/s3_chunk_stream.c @@ -27,6 +27,7 @@ 
struct aws_chunk_stream { /* Pointing to the stream we read from */ struct aws_input_stream *current_stream; struct aws_input_stream *chunk_body_stream; + struct aws_input_stream *checksum_stream; struct aws_s3_upload_request_checksum_context *checksum_context; struct aws_byte_buf pre_chunk_buffer; @@ -61,10 +62,15 @@ static int s_set_post_chunk_stream(struct aws_chunk_stream *parent_stream) { } struct aws_byte_cursor post_trailer_cursor = aws_byte_cursor_from_string(s_post_trailer); struct aws_byte_cursor colon_cursor = aws_byte_cursor_from_string(s_colon); - /* After the checksum stream released, the checksum will be calculated. */ - parent_stream->checksum_context->checksum_calculated = true; + if (parent_stream->checksum_stream) { + /* If we have the checksum stream, finalize the checksum now as we finished reading from it. */ + if (aws_checksum_stream_finalize_checksum_context( + parent_stream->checksum_stream, parent_stream->checksum_context)) { + return AWS_OP_ERR; + } + } struct aws_byte_cursor checksum_result_cursor = - aws_byte_cursor_from_buf(&parent_stream->checksum_context->base64_checksum); + aws_s3_upload_request_checksum_context_get_checksum_cursor(parent_stream->checksum_context); AWS_ASSERT(parent_stream->checksum_context->encoded_checksum_size == checksum_result_cursor.len); if (aws_byte_buf_init( @@ -159,14 +165,12 @@ static int s_aws_input_chunk_stream_get_length(struct aws_input_stream *stream, static void s_aws_input_chunk_stream_destroy(struct aws_chunk_stream *impl) { if (impl) { - if (impl->current_stream) { - aws_input_stream_release(impl->current_stream); - } - if (impl->chunk_body_stream) { - aws_input_stream_release(impl->chunk_body_stream); - } + aws_input_stream_release(impl->current_stream); + aws_input_stream_release(impl->chunk_body_stream); + aws_input_stream_release(impl->checksum_stream); aws_byte_buf_clean_up(&impl->pre_chunk_buffer); aws_byte_buf_clean_up(&impl->post_chunk_buffer); + /* Either we calculated the checksum, or 
we the checksum is empty. Otherwise, something was wrong. */ aws_s3_upload_request_checksum_context_release(impl->checksum_context); aws_mem_release(impl->allocator, impl); } @@ -194,13 +198,11 @@ struct aws_input_stream *aws_chunk_stream_new( /* Extract algorithm and buffer from context */ enum aws_s3_checksum_algorithm algorithm = AWS_SCA_NONE; - struct aws_byte_buf *checksum_buffer = NULL; impl->checksum_context = aws_s3_upload_request_checksum_context_acquire(checksum_context); algorithm = checksum_context->algorithm; - checksum_buffer = &checksum_context->base64_checksum; - bool checksum_calculated = checksum_context->checksum_calculated; + bool should_calculate_checksum = aws_s3_upload_request_checksum_context_should_calculate(impl->checksum_context); int64_t stream_length = 0; int64_t final_chunk_len = 0; @@ -223,12 +225,13 @@ struct aws_input_stream *aws_chunk_stream_new( if (aws_byte_buf_append(&impl->pre_chunk_buffer, &pre_chunk_cursor)) { goto error; } - if (!checksum_calculated) { + if (should_calculate_checksum) { /* Wrap the existing stream with checksum stream to calculate the checksum when reading from it. */ - impl->chunk_body_stream = aws_checksum_stream_new(allocator, existing_stream, algorithm, checksum_buffer); - if (impl->chunk_body_stream == NULL) { + impl->checksum_stream = aws_checksum_stream_new(allocator, existing_stream, algorithm); + if (impl->checksum_stream == NULL) { goto error; } + impl->chunk_body_stream = aws_input_stream_acquire(impl->checksum_stream); } else { /* No need to calculate the checksum during read, use the existing stream directly. 
*/ impl->chunk_body_stream = aws_input_stream_acquire(existing_stream); diff --git a/source/s3_request_messages.c b/source/s3_request_messages.c index 3e5a35693..ac10323d1 100644 --- a/source/s3_request_messages.c +++ b/source/s3_request_messages.c @@ -838,13 +838,9 @@ static int s_calculate_in_memory_checksum_helper( struct aws_byte_cursor data, struct aws_s3_upload_request_checksum_context *checksum_context) { AWS_ASSERT(checksum_context); - if (checksum_context->checksum_calculated) { - return AWS_OP_SUCCESS; - } int ret_code = AWS_OP_ERR; /* Calculate checksum for output buffer only (no header/trailer) */ - struct aws_byte_buf *output_buffer = aws_s3_upload_request_checksum_context_get_output_buffer(checksum_context); struct aws_byte_buf raw_checksum; size_t digest_size = aws_get_digest_size_from_checksum_algorithm(checksum_context->algorithm); aws_byte_buf_init(&raw_checksum, allocator, digest_size); @@ -855,18 +851,14 @@ static int s_calculate_in_memory_checksum_helper( } struct aws_byte_cursor raw_checksum_cursor = aws_byte_cursor_from_buf(&raw_checksum); - if (aws_base64_encode(&raw_checksum_cursor, output_buffer)) { + if (aws_s3_upload_request_checksum_context_finalize_checksum(checksum_context, raw_checksum_cursor)) { aws_byte_buf_clean_up(&raw_checksum); goto done; } aws_byte_buf_clean_up(&raw_checksum); - checksum_context->checksum_calculated = true; ret_code = AWS_OP_SUCCESS; done: - if (ret_code) { - aws_byte_buf_clean_up(output_buffer); - } aws_byte_buf_clean_up(&raw_checksum); return ret_code; } @@ -890,7 +882,8 @@ static int s_calculate_and_add_checksum_to_header_helper( /* Add the encoded checksum to header. 
*/ const struct aws_byte_cursor header_name = aws_get_http_header_name_from_checksum_algorithm(checksum_context->algorithm); - struct aws_byte_cursor encoded_checksum_val = aws_byte_cursor_from_buf(&checksum_context->base64_checksum); + struct aws_byte_cursor encoded_checksum_val = + aws_s3_upload_request_checksum_context_get_checksum_cursor(checksum_context); struct aws_http_headers *headers = aws_http_message_get_headers(out_message); if (aws_http_headers_set(headers, header_name, encoded_checksum_val)) { goto done; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bc3165f15..e2b29fbfa 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -294,7 +294,6 @@ add_test_case(crc32c_test_invalid_buffer) add_test_case(crc32c_test_oneshot) add_test_case(crc32c_test_invalid_state) -add_test_case(test_upload_request_checksum_context_get_output_buffer) add_test_case(test_upload_request_checksum_context_get_checksum_cursor) add_test_case(test_upload_request_checksum_context_error_cases) add_test_case(test_upload_request_checksum_context_different_algorithms) @@ -331,6 +330,7 @@ if(ENABLE_MOCK_SERVER_TESTS) add_net_test_case(multipart_upload_unsigned_with_trailer_checksum_mock_server) add_net_test_case(single_upload_unsigned_with_trailer_checksum_mock_server) add_net_test_case(multipart_upload_with_network_interface_names_mock_server) + add_net_test_case(multipart_upload_checksum_with_retry_before_finish_mock_server) add_net_test_case(multipart_upload_checksum_with_retry_mock_server) add_net_test_case(multipart_download_checksum_with_retry_mock_server) add_net_test_case(async_internal_error_from_complete_multipart_mock_server) diff --git a/tests/mock_s3_server/UploadPart/throttle_before_finish.json b/tests/mock_s3_server/UploadPart/throttle_before_finish.json new file mode 100644 index 000000000..b68e9d54e --- /dev/null +++ b/tests/mock_s3_server/UploadPart/throttle_before_finish.json @@ -0,0 +1,13 @@ +{ + "status": 503, + "headers": {"ETag": 
"b54357faf0632cce46e942fa68356b38", "Connection": "keep-alive"}, + "body": [ + "", + "", + "", + "SlowDown", + "656c76696e6727732072657175657374", + "Uuag1LuByRx9e6j5Onimru9pO4ZVKnJ2Qz7/C1NPcfTWAtRPfTaOFg==", + "" + ] + } diff --git a/tests/mock_s3_server/mock_s3_server.py b/tests/mock_s3_server/mock_s3_server.py index c90b1f2c5..706a9b830 100644 --- a/tests/mock_s3_server/mock_s3_server.py +++ b/tests/mock_s3_server/mock_s3_server.py @@ -51,16 +51,33 @@ class Response: @dataclass class ResponseConfig: path: str - disconnect_after_headers = False + request: Optional[object] = None # Add request as a field + disconnect_after_headers: bool = False generate_body_size: Optional[int] = None - json_path: str = None - throttle: bool = False + json_path: Optional[str] = None + forced_throttle: bool = False force_retry: bool = False + should_skip_wait: bool = False request_headers: Optional[List[Tuple[bytes, bytes]]] = None + def __post_init__(self): + """Called automatically after the dataclass __init__""" + if self.request is not None: + self.request_headers = self.request.headers + if get_request_header_value(self.request, "before_finish") is not None: + self.should_skip_wait = True + if get_request_header_value(self.request, "force_throttle") is not None: + self.forced_throttle = True + def _resolve_file_path(self, wrapper, request_type): global SHOULD_THROTTLE if self.json_path is None: + if self.forced_throttle: + # force the throttle to happend, instead of just 50%. 
+ response_file = os.path.join( + base_dir, request_type.name, f"throttle.json") + self.json_path = response_file + return response_file = os.path.join( base_dir, request_type.name, f"{self.path[1:]}.json") if os.path.exists(response_file) == False: @@ -481,19 +498,20 @@ async def handle_mock_s3_request(wrapper, request): wrapper.info("unsupported request:", request) request_type = S3Opts.CreateMultipartUpload - while True: - event = await wrapper.next_event() - if type(event) is h11.EndOfMessage: - break - assert type(event) is h11.Data - if response_config is None: - response_config = ResponseConfig(parsed_path.path) - response_config.request_headers = request.headers + response_config = ResponseConfig(parsed_path.path, request=request) + + if not response_config.should_skip_wait: + while True: + event = await wrapper.next_event() + if type(event) is h11.EndOfMessage: + break + assert type(event) is h11.Data + else: + print("Skipping waiting for request body") response = response_config.resolve_response( wrapper, request_type, head_request=method == "HEAD") - await send_response(wrapper, response) diff --git a/tests/s3_checksum_context_test.c b/tests/s3_checksum_context_test.c index 647293bd6..e1ab61ee9 100644 --- a/tests/s3_checksum_context_test.c +++ b/tests/s3_checksum_context_test.c @@ -8,38 +8,6 @@ #include #include -static int s_test_upload_request_checksum_context_get_output_buffer(struct aws_allocator *allocator, void *ctx) { - (void)ctx; - - struct aws_s3_meta_request_checksum_config_storage config = { - .allocator = allocator, - .checksum_algorithm = AWS_SCA_CRC32, - .location = AWS_SCL_HEADER, - .has_full_object_checksum = false, - }; - AWS_ZERO_STRUCT(config.full_object_checksum); - - /* Test get output buffer with valid context */ - struct aws_s3_upload_request_checksum_context *context = - aws_s3_upload_request_checksum_context_new(allocator, &config); - ASSERT_NOT_NULL(context); - - struct aws_byte_buf *output_buffer = 
aws_s3_upload_request_checksum_context_get_output_buffer(context); - ASSERT_NOT_NULL(output_buffer); - ASSERT_TRUE(output_buffer->capacity > 0); - - aws_s3_upload_request_checksum_context_release(context); - - /* Test get output buffer with NULL context */ - output_buffer = aws_s3_upload_request_checksum_context_get_output_buffer(NULL); - ASSERT_NULL(output_buffer); - - return AWS_OP_SUCCESS; -} -AWS_TEST_CASE( - test_upload_request_checksum_context_get_output_buffer, - s_test_upload_request_checksum_context_get_output_buffer) - static int s_test_upload_request_checksum_context_get_checksum_cursor(struct aws_allocator *allocator, void *ctx) { (void)ctx; diff --git a/tests/s3_checksum_stream_test.c b/tests/s3_checksum_stream_test.c index 76e63529f..e63b31fb6 100644 --- a/tests/s3_checksum_stream_test.c +++ b/tests/s3_checksum_stream_test.c @@ -28,8 +28,7 @@ static int compare_checksum_stream(struct aws_allocator *allocator, struct aws_b struct aws_byte_cursor checksum_result_cursor = aws_byte_cursor_from_buf(&compute_checksum_output); aws_base64_encode(&checksum_result_cursor, &compute_encoded_checksum_output); struct aws_input_stream *cursor_stream = aws_input_stream_new_from_cursor(allocator, input); - struct aws_input_stream *stream = - aws_checksum_stream_new(allocator, cursor_stream, algorithm, &stream_checksum_output); + struct aws_input_stream *stream = aws_checksum_stream_new(allocator, cursor_stream, algorithm); aws_input_stream_release(cursor_stream); struct aws_stream_status status; AWS_ZERO_STRUCT(status); @@ -38,11 +37,12 @@ static int compare_checksum_stream(struct aws_allocator *allocator, struct aws_b read_buf.len = 0; ASSERT_TRUE(aws_input_stream_get_status(stream, &status) == 0); } - aws_input_stream_release(stream); + aws_checksum_stream_finalize_checksum(stream, &stream_checksum_output); ASSERT_TRUE(aws_byte_buf_eq(&compute_encoded_checksum_output, &stream_checksum_output)); aws_byte_buf_clean_up(&compute_checksum_output); 
aws_byte_buf_clean_up(&stream_checksum_output); aws_byte_buf_clean_up(&compute_encoded_checksum_output); + aws_input_stream_release(stream); } aws_byte_buf_clean_up(&read_buf); return AWS_OP_SUCCESS; diff --git a/tests/s3_mock_server_tests.c b/tests/s3_mock_server_tests.c index b7b85f624..b6c83b248 100644 --- a/tests/s3_mock_server_tests.c +++ b/tests/s3_mock_server_tests.c @@ -435,13 +435,92 @@ TEST_CASE(multipart_upload_with_network_interface_names_mock_server) { } /* Total hack to flip the bytes. */ -static void s_after_prepare_upload_part_finish(struct aws_s3_request *request) { - if (request->num_times_prepared > 1) { +static void s_after_prepare_upload_part_finish(struct aws_s3_request *request, struct aws_http_message *message) { + (void)message; + if (request->num_times_prepared > 0) { /* mock that the body buffer was messed up in memory */ request->request_body.buffer[1]++; } } +static void s_after_prepare_upload_part_finish_retry_before_finish_sending( + struct aws_s3_request *request, + struct aws_http_message *message) { + if (request->num_times_prepared == 0 && message != NULL) { + struct aws_http_header before_finish_header = { + .name = aws_byte_cursor_from_c_str("before_finish"), + .value = aws_byte_cursor_from_c_str("true"), + }; + aws_http_message_add_header(message, before_finish_header); + struct aws_http_header throttle_header = { + .name = aws_byte_cursor_from_c_str("force_throttle"), + .value = aws_byte_cursor_from_c_str("true"), + }; + aws_http_message_add_header(message, throttle_header); + } + if (request->num_times_prepared > 0) { + /* mock that the body buffer was messed up in memory */ + request->request_body.buffer[1]++; + } +} + +/** + * This test is built for + * 1. The retry happens before the upload has finished. 
+ */ +TEST_CASE(multipart_upload_checksum_with_retry_before_finish_mock_server) { + (void)ctx; + struct aws_s3_tester tester; + ASSERT_SUCCESS(aws_s3_tester_init(allocator, &tester)); + struct aws_s3_tester_client_options client_options = { + .part_size = MB_TO_BYTES(5), + .tls_usage = AWS_S3_TLS_DISABLED, + }; + + struct aws_s3_client *client = NULL; + ASSERT_SUCCESS(aws_s3_tester_client_new(&tester, &client_options, &client)); + struct aws_s3_client_vtable *patched_client_vtable = aws_s3_tester_patch_client_vtable(&tester, client, NULL); + patched_client_vtable->after_prepare_upload_part_finish = + s_after_prepare_upload_part_finish_retry_before_finish_sending; + + struct aws_byte_cursor object_path = aws_byte_cursor_from_c_str("/throttle"); + { + /* 1. Trailer checksum */ + struct aws_s3_tester_meta_request_options put_options = { + .allocator = allocator, + .meta_request_type = AWS_S3_META_REQUEST_TYPE_PUT_OBJECT, + .client = client, + .checksum_algorithm = AWS_SCA_CRC32, + .validate_get_response_checksum = false, + .put_options = + { + .object_size_mb = 10, + .object_path_override = object_path, + }, + .mock_server = true, + }; + + struct aws_s3_meta_request_test_results meta_request_test_results; + aws_s3_meta_request_test_results_init(&meta_request_test_results, allocator); + + ASSERT_SUCCESS(aws_s3_tester_send_meta_request_with_options(&tester, &put_options, &meta_request_test_results)); + + ASSERT_INT_EQUALS(meta_request_test_results.upload_review.part_count, 2); + /* Note: the data we currently generate is always the same, + * The retry got the messed-up data, while the first run never actually finished reading the bytes, so the messed + up data checksum ended up being sent. 
*/ + ASSERT_STR_EQUALS( + "dKYRxA==", aws_string_c_str(meta_request_test_results.upload_review.part_checksums_array[0])); + ASSERT_STR_EQUALS( + "dxV2Sw==", aws_string_c_str(meta_request_test_results.upload_review.part_checksums_array[1])); + aws_s3_meta_request_test_results_clean_up(&meta_request_test_results); + } + aws_s3_client_release(client); + aws_s3_tester_clean_up(&tester); + + return AWS_OP_SUCCESS; +} + /** * This test is built for * 1. We had a memory leak when the retry was triggered and the checksum was calculated.