Skip to content

Commit

Permalink
Add support for progress_callback in Object#download_file
Browse files Browse the repository at this point in the history
  • Loading branch information
alextwoods committed Aug 18, 2023
1 parent 24b7683 commit 0d2c1a5
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 24 deletions.
2 changes: 1 addition & 1 deletion build_tools/services.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ServiceEnumerator
MINIMUM_CORE_VERSION = "3.177.0"

# Minimum `aws-sdk-core` version for new S3 gem builds
MINIMUM_CORE_VERSION_S3 = "3.179.0"
MINIMUM_CORE_VERSION_S3 = "3.181.0"

EVENTSTREAM_PLUGIN = "Aws::Plugins::EventStreamConfiguration"

Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `on_chunk_received` callback.

3.180.3 (2023-08-09)
------------------

Expand Down
31 changes: 31 additions & 0 deletions gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class RequestCallback < Plugin
bytes in the body.
DOCS

option(:on_chunk_received,
default: nil,
doc_type: 'Proc',
docstring: <<-DOCS)
When a Proc object is provided, it will be used as callback when each chunk
of the response body is received. It provides three arguments: the chunk,
the number of bytes received, and the total number of
bytes in the response (or nil if the server did not send a `content-length`).
DOCS

# @api private
class OptionHandler < Client::Handler
def call(context)
Expand All @@ -68,8 +78,29 @@ def call(context)
end
on_chunk_sent = context.config.on_chunk_sent if on_chunk_sent.nil?
context[:on_chunk_sent] = on_chunk_sent if on_chunk_sent

if context.params.is_a?(Hash) && context.params[:on_chunk_received]
on_chunk_received = context.params.delete(:on_chunk_received)
end
on_chunk_received = context.config.on_chunk_received if on_chunk_received.nil?

add_response_events(on_chunk_received, context) if on_chunk_received

@handler.call(context)
end

def add_response_events(on_chunk_received, context)
shared_data = {bytes_received: 0}

context.http_response.on_headers do |_status, headers|
shared_data[:content_length] = headers['content-length']&.to_i
end

context.http_response.on_data do |chunk|
shared_data[:bytes_received] += chunk.bytesize if chunk && chunk.respond_to?(:bytesize)
on_chunk_received.call(chunk, shared_data[:bytes_received], shared_data[:content_length])
end
end
end

# @api private
Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-s3/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `progress_callback` in `Object#download_file` and improve multi-threaded performance #(2901).

1.132.1 (2023-08-09)
------------------

Expand Down
16 changes: 16 additions & 0 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ def upload_file(source, options = {})
# # and the parts are downloaded in parallel
# obj.download_file('/path/to/very_large_file')
#
# You can provide a callback to monitor progress of the download:
#
# # bytes and part_sizes are each an array with 1 entry per part
# # part_sizes may not be known until the first bytes are retrieved
# progress = Proc.new do |bytes, part_sizes, file_size|
# puts bytes.map.with_index { |b, i| "Part #{i+1}: #{b} / #{part_sizes[i]}"}.join(' ') + "Total: #{100.0 * bytes.sum / file_size}%" }
# end
# obj.download_file('/path/to/file', progress_callback: progress)
#
# @param [String] destination Where to download the file to.
#
# @option options [String] mode `auto`, `single_request`, `get_range`
Expand Down Expand Up @@ -505,6 +514,13 @@ def upload_file(source, options = {})
# response. For multipart downloads, this will be called for each
# part that is downloaded and validated.
#
# @option options [Proc] :progress_callback
# A Proc that will be called when each chunk of the download is received.
# It will be invoked with [bytes_read], [part_sizes], file_size.
# When the object is downloaded as parts (rather than by ranges), the
# part_sizes will not be known ahead of time and will be nil in the
# callback until the first bytes in the part are received.
#
# @return [Boolean] Returns `true` when the file is downloaded without
# any errors.
def download_file(destination, options = {})
Expand Down
134 changes: 111 additions & 23 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def download(destination, options = {})
end
@on_checksum_validated = options[:on_checksum_validated]

@progress_callback = options[:progress_callback]

validate!

Aws::Plugins::UserAgent.feature('s3-transfer') do
Expand All @@ -49,7 +51,7 @@ def download(destination, options = {})
when 'get_range'
if @chunk_size
resp = @client.head_object(@params)
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
else
msg = 'In :get_range mode, :chunk_size must be provided'
raise ArgumentError, msg
Expand Down Expand Up @@ -82,7 +84,7 @@ def multipart_download
if resp.content_length <= MIN_CHUNK_SIZE
single_request
else
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
end
else
# partNumber is an option
Expand All @@ -99,9 +101,9 @@ def compute_mode(file_size, count)
chunk_size = compute_chunk(file_size)
part_size = (file_size.to_f / count.to_f).ceil
if chunk_size < part_size
multithreaded_get_by_ranges(construct_chunks(file_size))
multithreaded_get_by_ranges(file_size)
else
multithreaded_get_by_parts(count)
multithreaded_get_by_parts(count, file_size)
end
end

Expand Down Expand Up @@ -133,30 +135,65 @@ def batches(chunks, mode)
chunks.each_slice(@thread_count).to_a
end

def multithreaded_get_by_ranges(chunks)
thread_batches(chunks, 'range')
def multithreaded_get_by_ranges(file_size)
offset = 0
default_chunk_size = compute_chunk(file_size)
chunks = []
part_number = 1 # parts start at 1
while offset < file_size
progress = offset + default_chunk_size
progress = file_size if progress > file_size
range = "bytes=#{offset}-#{progress - 1}"
chunks << Part.new(
part_number: part_number,
size: (progress-offset),
params: @params.merge(range: range)
)
part_number += 1
offset = progress
end
download_in_threads(PartList.new(chunks), file_size)
end

def multithreaded_get_by_parts(parts)
thread_batches(parts, 'part_number')
def multithreaded_get_by_parts(n_parts, total_size)
parts = (1..n_parts).map do |part|
Part.new(part_number: part, params: @params.merge(part_number: part))
end
download_in_threads(PartList.new(parts), total_size)
end

def thread_batches(chunks, param)
batches(chunks, param).each do |batch|
threads = []
batch.each do |chunk|
threads << Thread.new do
resp = @client.get_object(
@params.merge(param.to_sym => chunk)
)
write(resp)
if @on_checksum_validated && resp.checksum_validated
@on_checksum_validated.call(resp.checksum_validated, resp)
def download_in_threads(pending, total_size)
threads = []
if @progress_callback
progress = MultipartProgress.new(pending, total_size, @progress_callback)
end
@thread_count.times do
thread = Thread.new do
begin
while part = pending.shift
if progress
part.params[:on_chunk_received] =
proc do |_chunk, bytes, total|
progress.call(part.part_number, bytes, total)
end
end
resp = @client.get_object(part.params)
write(resp)
if @on_checksum_validated && resp.checksum_validated
@on_checksum_validated.call(resp.checksum_validated, resp)
end
end
nil
rescue => error
# keep other threads from downloading other parts
pending.clear!
error
end
end
threads.each(&:join)
thread.abort_on_exception = true
threads << thread
end
threads.map(&:value).compact
end

def write(resp)
Expand All @@ -166,9 +203,9 @@ def write(resp)
end

def single_request
resp = @client.get_object(
@params.merge(response_target: @path)
)
params = @params.merge(response_target: @path)
params[:on_chunk_received] = single_part_progress if @progress_callback
resp = @client.get_object(params)

return resp unless @on_checksum_validated

Expand All @@ -178,6 +215,57 @@ def single_request

resp
end

def single_part_progress
proc do |_chunk, bytes_read, total_size|
@progress_callback.call([bytes_read], [total_size])
end
end

Part = Structure.new(:part_number, :size, :params, keyword_init: true)

# @api private
class PartList
include Enumerable
def initialize(parts = [])
@parts = parts
@mutex = Mutex.new
end

def shift
@mutex.synchronize { @parts.shift }
end

def size
@mutex.synchronize { @parts.size }
end

def clear!
@mutex.synchronize { @parts.clear }
end

def each(&block)
@mutex.synchronize { @parts.each(&block) }
end
end

# @api private
class MultipartProgress
def initialize(parts, total_size, progress_callback)
@bytes_received = Array.new(parts.size, 0)
@part_sizes = parts.map(&:size)
@total_size = total_size
@progress_callback = progress_callback
end

def call(part_number, bytes_received, total)
# part numbers start at 1
@bytes_received[part_number - 1] = bytes_received
# part size may not be known until we get the first response
@part_sizes[part_number - 1] ||= total
@progress_callback.call(@bytes_received, @part_sizes, @total_size)
end
end
end
end
end

0 comments on commit 0d2c1a5

Please sign in to comment.