From 8e6aac3a03f0beba9147bc1546f27917a49d8881 Mon Sep 17 00:00:00 2001 From: Alex Woods Date: Tue, 22 Aug 2023 08:44:02 -0700 Subject: [PATCH] Add support for progress_callback in Object#download_file (#2902) --- build_tools/services.rb | 2 +- gems/aws-sdk-core/CHANGELOG.md | 2 + .../client/plugins/request_callback.rb | 31 ++++ gems/aws-sdk-s3/CHANGELOG.md | 2 + .../lib/aws-sdk-s3/customizations/object.rb | 42 ++++++ .../lib/aws-sdk-s3/file_downloader.rb | 136 +++++++++++++++--- .../lib/aws-sdk-s3/multipart_upload_part.rb | 2 +- .../spec/object/download_file_spec.rb | 50 ++++++- 8 files changed, 241 insertions(+), 26 deletions(-) diff --git a/build_tools/services.rb b/build_tools/services.rb index 42542df3380..ad6662f27c1 100644 --- a/build_tools/services.rb +++ b/build_tools/services.rb @@ -13,7 +13,7 @@ class ServiceEnumerator MINIMUM_CORE_VERSION = "3.177.0" # Minimum `aws-sdk-core` version for new S3 gem builds - MINIMUM_CORE_VERSION_S3 = "3.179.0" + MINIMUM_CORE_VERSION_S3 = "3.181.0" EVENTSTREAM_PLUGIN = "Aws::Plugins::EventStreamConfiguration" diff --git a/gems/aws-sdk-core/CHANGELOG.md b/gems/aws-sdk-core/CHANGELOG.md index a5f24a0c61f..4270bc4afb8 100644 --- a/gems/aws-sdk-core/CHANGELOG.md +++ b/gems/aws-sdk-core/CHANGELOG.md @@ -1,6 +1,8 @@ Unreleased Changes ------------------ +* Feature - Add support for `on_chunk_received` callback. + 3.180.3 (2023-08-09) ------------------ diff --git a/gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb b/gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb index 593f558d761..375ebb1856a 100644 --- a/gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb +++ b/gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb @@ -60,6 +60,16 @@ class RequestCallback < Plugin bytes in the body. DOCS + option(:on_chunk_received, + default: nil, + doc_type: 'Proc', + docstring: <<-DOCS) +When a Proc object is provided, it will be used as callback when each chunk +of the response body is received. It provides three arguments: the chunk, +the number of bytes received, and the total number of +bytes in the response (or nil if the server did not send a `content-length`). + DOCS + # @api private class OptionHandler < Client::Handler def call(context) @@ -68,8 +78,29 @@ def call(context) end on_chunk_sent = context.config.on_chunk_sent if on_chunk_sent.nil? context[:on_chunk_sent] = on_chunk_sent if on_chunk_sent + + if context.params.is_a?(Hash) && context.params[:on_chunk_received] + on_chunk_received = context.params.delete(:on_chunk_received) + end + on_chunk_received = context.config.on_chunk_received if on_chunk_received.nil? + + add_response_events(on_chunk_received, context) if on_chunk_received + @handler.call(context) end + + def add_response_events(on_chunk_received, context) + shared_data = {bytes_received: 0} + + context.http_response.on_headers do |_status, headers| + shared_data[:content_length] = headers['content-length']&.to_i + end + + context.http_response.on_data do |chunk| + shared_data[:bytes_received] += chunk.bytesize if chunk && chunk.respond_to?(:bytesize) + on_chunk_received.call(chunk, shared_data[:bytes_received], shared_data[:content_length]) + end + end end # @api private diff --git a/gems/aws-sdk-s3/CHANGELOG.md b/gems/aws-sdk-s3/CHANGELOG.md index 9cbbed9ccdb..da1b21295ca 100644 --- a/gems/aws-sdk-s3/CHANGELOG.md +++ b/gems/aws-sdk-s3/CHANGELOG.md @@ -1,6 +1,8 @@ Unreleased Changes ------------------ +* Feature - Add support for `progress_callback` in `Object#download_file` and improve multi-threaded performance #(2901). + 1.132.1 (2023-08-09) ------------------ diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb index efa848efea4..10c8fd7be2c 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb @@ -353,6 +353,10 @@ def public_url(options = {}) # obj.upload_stream do |write_stream| # IO.copy_stream(STDIN, write_stream) # end + # @param [Hash] options + # Additional options for {Client#create_multipart_upload}, + # {Client#complete_multipart_upload}, + # and {Client#upload_part} can be provided. # # @option options [Integer] :thread_count (10) The number of parallel # multipart uploads @@ -375,6 +379,9 @@ def public_url(options = {}) # @return [Boolean] Returns `true` when the object is uploaded # without any errors. # + # @see Client#create_multipart_upload + # @see Client#complete_multipart_upload + # @see Client#upload_part def upload_stream(options = {}, &block) uploading_options = options.dup uploader = MultipartStreamUploader.new( @@ -427,6 +434,13 @@ def upload_stream(options = {}, &block) # using an open Tempfile, rewind it before uploading or else the object # will be empty. # + # @param [Hash] options + # Additional options for {Client#put_object} + # when file sizes below the multipart threshold. For files larger than + # the multipart threshold, options for {Client#create_multipart_upload}, + # {Client#complete_multipart_upload}, + # and {Client#upload_part} can be provided. + # # @option options [Integer] :multipart_threshold (104857600) Files larger # than or equal to `:multipart_threshold` are uploaded using the S3 # multipart APIs. @@ -448,6 +462,11 @@ def upload_stream(options = {}, &block) # # @return [Boolean] Returns `true` when the object is uploaded # without any errors. + # + # @see Client#put_object + # @see Client#create_multipart_upload + # @see Client#complete_multipart_upload + # @see Client#upload_part def upload_file(source, options = {}) uploading_options = options.dup uploader = FileUploader.new( @@ -475,8 +494,21 @@ def upload_file(source, options = {}) # # and the parts are downloaded in parallel # obj.download_file('/path/to/very_large_file') # + # You can provide a callback to monitor progress of the download: + # + # # bytes and part_sizes are each an array with 1 entry per part + # # part_sizes may not be known until the first bytes are retrieved + # progress = Proc.new do |bytes, part_sizes, file_size| + # puts bytes.map.with_index { |b, i| "Part #{i+1}: #{b} / #{part_sizes[i]}"}.join(' ') + "Total: #{100.0 * bytes.sum / file_size}%" } + # end + # obj.download_file('/path/to/file', progress_callback: progress) + # # @param [String] destination Where to download the file to. # + # @param [Hash] options + # Additional options for {Client#get_object} and #{Client#head_object} + # may be provided. + # # @option options [String] mode `auto`, `single_request`, `get_range` # `single_request` mode forces only 1 GET request is made in download, # `get_range` mode allows `chunk_size` parameter to configured in @@ -505,8 +537,18 @@ def upload_file(source, options = {}) # response. For multipart downloads, this will be called for each # part that is downloaded and validated. # + # @option options [Proc] :progress_callback + # A Proc that will be called when each chunk of the download is received. + # It will be invoked with [bytes_read], [part_sizes], file_size. + # When the object is downloaded as parts (rather than by ranges), the + # part_sizes will not be known ahead of time and will be nil in the + # callback until the first bytes in the part are received. + # # @return [Boolean] Returns `true` when the file is downloaded without # any errors. + # + # @see Client#get_object + # @see Client#head_object def download_file(destination, options = {}) downloader = FileDownloader.new(client: client) Aws::Plugins::UserAgent.feature('resource') do diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb index 01448298536..f43149267c4 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb @@ -40,6 +40,8 @@ def download(destination, options = {}) end @on_checksum_validated = options[:on_checksum_validated] + @progress_callback = options[:progress_callback] + validate! Aws::Plugins::UserAgent.feature('s3-transfer') do @@ -49,7 +51,7 @@ def download(destination, options = {}) when 'get_range' if @chunk_size resp = @client.head_object(@params) - multithreaded_get_by_ranges(construct_chunks(resp.content_length)) + multithreaded_get_by_ranges(resp.content_length) else msg = 'In :get_range mode, :chunk_size must be provided' raise ArgumentError, msg @@ -82,7 +84,7 @@ def multipart_download if resp.content_length <= MIN_CHUNK_SIZE single_request else - multithreaded_get_by_ranges(construct_chunks(resp.content_length)) + multithreaded_get_by_ranges(resp.content_length) end else # partNumber is an option @@ -99,9 +101,9 @@ def compute_mode(file_size, count) chunk_size = compute_chunk(file_size) part_size = (file_size.to_f / count.to_f).ceil if chunk_size < part_size - multithreaded_get_by_ranges(construct_chunks(file_size)) + multithreaded_get_by_ranges(file_size) else - multithreaded_get_by_parts(count) + multithreaded_get_by_parts(count, file_size) end end @@ -133,30 +135,65 @@ def batches(chunks, mode) chunks.each_slice(@thread_count).to_a end - def multithreaded_get_by_ranges(chunks) - thread_batches(chunks, 'range') + def multithreaded_get_by_ranges(file_size) + offset = 0 + default_chunk_size = compute_chunk(file_size) + chunks = [] + part_number = 1 # parts start at 1 + while offset < file_size + progress = offset + default_chunk_size + progress = file_size if progress > file_size + range = "bytes=#{offset}-#{progress - 1}" + chunks << Part.new( + part_number: part_number, + size: (progress-offset), + params: @params.merge(range: range) + ) + part_number += 1 + offset = progress + end + download_in_threads(PartList.new(chunks), file_size) end - def multithreaded_get_by_parts(parts) - thread_batches(parts, 'part_number') + def multithreaded_get_by_parts(n_parts, total_size) + parts = (1..n_parts).map do |part| + Part.new(part_number: part, params: @params.merge(part_number: part)) + end + download_in_threads(PartList.new(parts), total_size) end - def thread_batches(chunks, param) - batches(chunks, param).each do |batch| - threads = [] - batch.each do |chunk| - threads << Thread.new do - resp = @client.get_object( - @params.merge(param.to_sym => chunk) - ) - write(resp) - if @on_checksum_validated && resp.checksum_validated - @on_checksum_validated.call(resp.checksum_validated, resp) + def download_in_threads(pending, total_size) + threads = [] + if @progress_callback + progress = MultipartProgress.new(pending, total_size, @progress_callback) + end + @thread_count.times do + thread = Thread.new do + begin + while part = pending.shift + if progress + part.params[:on_chunk_received] = + proc do |_chunk, bytes, total| + progress.call(part.part_number, bytes, total) + end + end + resp = @client.get_object(part.params) + write(resp) + if @on_checksum_validated && resp.checksum_validated + @on_checksum_validated.call(resp.checksum_validated, resp) + end end + nil + rescue => error + # keep other threads from downloading other parts + pending.clear! + raise error end end - threads.each(&:join) + thread.abort_on_exception = true + threads << thread end + threads.map(&:value).compact end def write(resp) @@ -166,9 +203,9 @@ def write(resp) end def single_request - resp = @client.get_object( - @params.merge(response_target: @path) - ) + params = @params.merge(response_target: @path) + params[:on_chunk_received] = single_part_progress if @progress_callback + resp = @client.get_object(params) return resp unless @on_checksum_validated @@ -178,6 +215,59 @@ def single_request resp end + + def single_part_progress + proc do |_chunk, bytes_read, total_size| + @progress_callback.call([bytes_read], [total_size], total_size) + end + end + + class Part < Struct.new(:part_number, :size, :params) + include Aws::Structure + end + + # @api private + class PartList + include Enumerable + def initialize(parts = []) + @parts = parts + @mutex = Mutex.new + end + + def shift + @mutex.synchronize { @parts.shift } + end + + def size + @mutex.synchronize { @parts.size } + end + + def clear! + @mutex.synchronize { @parts.clear } + end + + def each(&block) + @mutex.synchronize { @parts.each(&block) } + end + end + + # @api private + class MultipartProgress + def initialize(parts, total_size, progress_callback) + @bytes_received = Array.new(parts.size, 0) + @part_sizes = parts.map(&:size) + @total_size = total_size + @progress_callback = progress_callback + end + + def call(part_number, bytes_received, total) + # part numbers start at 1 + @bytes_received[part_number - 1] = bytes_received + # part size may not be known until we get the first response + @part_sizes[part_number - 1] ||= total + @progress_callback.call(@bytes_received, @part_sizes, @total_size) + end + end end end end diff --git a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_upload_part.rb b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_upload_part.rb index 34cbef50239..d47ffb9ff8d 100644 --- a/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_upload_part.rb +++ b/gems/aws-sdk-s3/lib/aws-sdk-s3/multipart_upload_part.rb @@ -286,7 +286,7 @@ def wait_until(options = {}, &block) # @option options [required, String] :copy_source # Specifies the source object for the copy operation. You specify the # value in one of two formats, depending on whether you want to access - # the source object through an [access point][1]: + # the source object through an [access point][1]\: # # * For objects not accessed through an access point, specify the name # of the source bucket and key of the source object, separated by a diff --git a/gems/aws-sdk-s3/spec/object/download_file_spec.rb b/gems/aws-sdk-s3/spec/object/download_file_spec.rb index a8cc498f894..c85753e1c58 100644 --- a/gems/aws-sdk-s3/spec/object/download_file_spec.rb +++ b/gems/aws-sdk-s3/spec/object/download_file_spec.rb @@ -121,6 +121,30 @@ module S3 small_obj.download_file(path) end + it 'reports progress for single part objects' do + small_file_size = 1024 + expect(client).to receive(:get_object).with({ + bucket: 'bucket', + key: 'small', + checksum_mode: 'ENABLED', + response_target: path, + on_chunk_received: instance_of(Proc) + }) do |args| + args[:on_chunk_received].call(small_file, small_file_size, small_file_size) + end + + n_calls = 0 + callback = proc do |bytes, part_sizes, total| + expect(bytes).to eq([small_file_size]) + expect(part_sizes).to eq([small_file_size]) + expect(total).to eq(small_file_size) + n_calls += 1 + end + + small_obj.download_file(path, progress_callback: callback) + expect(n_calls).to eq(1) + end + it 'download larger files in parts' do expect(client).to receive(:head_object).with({ bucket: 'bucket', @@ -136,6 +160,28 @@ module S3 large_obj.download_file(path) end + it 'reports progress for files downloaded in parts' do + expect(client).to receive(:get_object).exactly(4).times do |args| + args[:on_chunk_received].call(large_file, 4, 4) + client.stub_data( + :get_object, + body: StringIO.new('chunk'), content_range: 'bytes 0-4/4' + ) + end + + n_calls = 0 + callback = proc do |bytes, part_sizes, total| + expect(bytes.size).to eq(4) + expect(part_sizes.size).to eq(4) + expect(total).to eq(20*one_meg) + n_calls += 1 + end + + large_obj.download_file(path, progress_callback: callback) + + expect(n_calls).to eq(4) + end + it 'download larger files in ranges' do expect(client).to receive(:head_object).with({ bucket: 'bucket', @@ -172,8 +218,10 @@ module S3 end it 'raises an error when checksum validation fails on multipart' do + thread = double(value: nil) client.stub_responses(:get_object, {body: 'body', checksum_sha1: 'invalid'}) - expect(Thread).to receive(:new).and_yield + expect(Thread).to receive(:new).and_yield.and_return(thread) + allow(thread).to receive(:abort_on_exception) expect do large_obj.download_file(path)