Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for progress_callback in Object#download_file #2902

Merged
merged 4 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/services.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ServiceEnumerator
MINIMUM_CORE_VERSION = "3.177.0"

# Minimum `aws-sdk-core` version for new S3 gem builds
MINIMUM_CORE_VERSION_S3 = "3.179.0"
MINIMUM_CORE_VERSION_S3 = "3.181.0"

EVENTSTREAM_PLUGIN = "Aws::Plugins::EventStreamConfiguration"

Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `on_chunk_received` callback.

3.180.3 (2023-08-09)
------------------

Expand Down
31 changes: 31 additions & 0 deletions gems/aws-sdk-core/lib/seahorse/client/plugins/request_callback.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class RequestCallback < Plugin
bytes in the body.
DOCS

option(:on_chunk_received,
mullermp marked this conversation as resolved.
Show resolved Hide resolved
default: nil,
doc_type: 'Proc',
docstring: <<-DOCS)
When a Proc object is provided, it will be used as callback when each chunk
of the response body is received. It provides three arguments: the chunk,
the number of bytes received, and the total number of
bytes in the response (or nil if the server did not send a `content-length`).
DOCS

# @api private
class OptionHandler < Client::Handler
def call(context)
Expand All @@ -68,8 +78,29 @@ def call(context)
end
on_chunk_sent = context.config.on_chunk_sent if on_chunk_sent.nil?
context[:on_chunk_sent] = on_chunk_sent if on_chunk_sent

if context.params.is_a?(Hash) && context.params[:on_chunk_received]
on_chunk_received = context.params.delete(:on_chunk_received)
end
on_chunk_received = context.config.on_chunk_received if on_chunk_received.nil?

add_response_events(on_chunk_received, context) if on_chunk_received

@handler.call(context)
end

def add_response_events(on_chunk_received, context)
shared_data = {bytes_received: 0}

context.http_response.on_headers do |_status, headers|
shared_data[:content_length] = headers['content-length']&.to_i
end

context.http_response.on_data do |chunk|
shared_data[:bytes_received] += chunk.bytesize if chunk && chunk.respond_to?(:bytesize)
on_chunk_received.call(chunk, shared_data[:bytes_received], shared_data[:content_length])
end
end
end

# @api private
Expand Down
2 changes: 2 additions & 0 deletions gems/aws-sdk-s3/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Unreleased Changes
------------------

* Feature - Add support for `progress_callback` in `Object#download_file` and improve multi-threaded performance #(2901).

1.132.1 (2023-08-09)
------------------

Expand Down
16 changes: 16 additions & 0 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/customizations/object.rb
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ def upload_file(source, options = {})
# # and the parts are downloaded in parallel
# obj.download_file('/path/to/very_large_file')
#
# You can provide a callback to monitor progress of the download:
#
# # bytes and part_sizes are each an array with 1 entry per part
# # part_sizes may not be known until the first bytes are retrieved
# progress = Proc.new do |bytes, part_sizes, file_size|
# puts bytes.map.with_index { |b, i| "Part #{i+1}: #{b} / #{part_sizes[i]}"}.join(' ') + "Total: #{100.0 * bytes.sum / file_size}%" }
# end
# obj.download_file('/path/to/file', progress_callback: progress)
#
# @param [String] destination Where to download the file to.
#
# @option options [String] mode `auto`, `single_request`, `get_range`
Expand Down Expand Up @@ -505,6 +514,13 @@ def upload_file(source, options = {})
# response. For multipart downloads, this will be called for each
# part that is downloaded and validated.
#
# @option options [Proc] :progress_callback
mullermp marked this conversation as resolved.
Show resolved Hide resolved
# A Proc that will be called when each chunk of the download is received.
# It will be invoked with [bytes_read], [part_sizes], file_size.
# When the object is downloaded as parts (rather than by ranges), the
# part_sizes will not be known ahead of time and will be nil in the
# callback until the first bytes in the part are received.
#
# @return [Boolean] Returns `true` when the file is downloaded without
# any errors.
def download_file(destination, options = {})
Expand Down
136 changes: 113 additions & 23 deletions gems/aws-sdk-s3/lib/aws-sdk-s3/file_downloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def download(destination, options = {})
end
@on_checksum_validated = options[:on_checksum_validated]

@progress_callback = options[:progress_callback]

validate!

Aws::Plugins::UserAgent.feature('s3-transfer') do
Expand All @@ -49,7 +51,7 @@ def download(destination, options = {})
when 'get_range'
if @chunk_size
resp = @client.head_object(@params)
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
else
msg = 'In :get_range mode, :chunk_size must be provided'
raise ArgumentError, msg
Expand Down Expand Up @@ -82,7 +84,7 @@ def multipart_download
if resp.content_length <= MIN_CHUNK_SIZE
single_request
else
multithreaded_get_by_ranges(construct_chunks(resp.content_length))
multithreaded_get_by_ranges(resp.content_length)
end
else
# partNumber is an option
Expand All @@ -99,9 +101,9 @@ def compute_mode(file_size, count)
chunk_size = compute_chunk(file_size)
part_size = (file_size.to_f / count.to_f).ceil
if chunk_size < part_size
multithreaded_get_by_ranges(construct_chunks(file_size))
multithreaded_get_by_ranges(file_size)
else
multithreaded_get_by_parts(count)
multithreaded_get_by_parts(count, file_size)
end
end

Expand Down Expand Up @@ -133,30 +135,65 @@ def batches(chunks, mode)
chunks.each_slice(@thread_count).to_a
end

def multithreaded_get_by_ranges(chunks)
thread_batches(chunks, 'range')
def multithreaded_get_by_ranges(file_size)
offset = 0
default_chunk_size = compute_chunk(file_size)
chunks = []
part_number = 1 # parts start at 1
while offset < file_size
progress = offset + default_chunk_size
progress = file_size if progress > file_size
range = "bytes=#{offset}-#{progress - 1}"
chunks << Part.new(
part_number: part_number,
size: (progress-offset),
params: @params.merge(range: range)
)
part_number += 1
offset = progress
end
download_in_threads(PartList.new(chunks), file_size)
end

def multithreaded_get_by_parts(parts)
thread_batches(parts, 'part_number')
def multithreaded_get_by_parts(n_parts, total_size)
parts = (1..n_parts).map do |part|
Part.new(part_number: part, params: @params.merge(part_number: part))
end
download_in_threads(PartList.new(parts), total_size)
end

def thread_batches(chunks, param)
batches(chunks, param).each do |batch|
threads = []
batch.each do |chunk|
threads << Thread.new do
resp = @client.get_object(
@params.merge(param.to_sym => chunk)
)
write(resp)
if @on_checksum_validated && resp.checksum_validated
@on_checksum_validated.call(resp.checksum_validated, resp)
def download_in_threads(pending, total_size)
threads = []
if @progress_callback
progress = MultipartProgress.new(pending, total_size, @progress_callback)
end
@thread_count.times do
thread = Thread.new do
begin
while part = pending.shift
if progress
part.params[:on_chunk_received] =
proc do |_chunk, bytes, total|
progress.call(part.part_number, bytes, total)
end
end
resp = @client.get_object(part.params)
write(resp)
if @on_checksum_validated && resp.checksum_validated
@on_checksum_validated.call(resp.checksum_validated, resp)
end
end
nil
rescue => error
# keep other threads from downloading other parts
pending.clear!
raise error
end
end
threads.each(&:join)
thread.abort_on_exception = true
threads << thread
end
threads.map(&:value).compact
end

def write(resp)
Expand All @@ -166,9 +203,9 @@ def write(resp)
end

def single_request
resp = @client.get_object(
@params.merge(response_target: @path)
)
params = @params.merge(response_target: @path)
params[:on_chunk_received] = single_part_progress if @progress_callback
resp = @client.get_object(params)

return resp unless @on_checksum_validated

Expand All @@ -178,6 +215,59 @@ def single_request

resp
end

def single_part_progress
proc do |_chunk, bytes_read, total_size|
@progress_callback.call([bytes_read], [total_size])
end
end

class Part < Struct.new(:part_number, :size, :params)
include Aws::Structure
end

# @api private
class PartList
mullermp marked this conversation as resolved.
Show resolved Hide resolved
include Enumerable
def initialize(parts = [])
@parts = parts
@mutex = Mutex.new
end

def shift
@mutex.synchronize { @parts.shift }
end

def size
@mutex.synchronize { @parts.size }
end

def clear!
@mutex.synchronize { @parts.clear }
end

def each(&block)
@mutex.synchronize { @parts.each(&block) }
end
end

# @api private
class MultipartProgress
def initialize(parts, total_size, progress_callback)
@bytes_received = Array.new(parts.size, 0)
@part_sizes = parts.map(&:size)
@total_size = total_size
@progress_callback = progress_callback
end

def call(part_number, bytes_received, total)
# part numbers start at 1
@bytes_received[part_number - 1] = bytes_received
# part size may not be known until we get the first response
@part_sizes[part_number - 1] ||= total
@progress_callback.call(@bytes_received, @part_sizes, @total_size)
end
end
end
end
end
4 changes: 3 additions & 1 deletion gems/aws-sdk-s3/spec/object/download_file_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ module S3
end

it 'raises an error when checksum validation fails on multipart' do
mullermp marked this conversation as resolved.
Show resolved Hide resolved
thread = double(value: nil)
client.stub_responses(:get_object, {body: 'body', checksum_sha1: 'invalid'})
expect(Thread).to receive(:new).and_yield
expect(Thread).to receive(:new).and_yield.and_return(thread)
allow(thread).to receive(:abort_on_exception=)

expect do
large_obj.download_file(path)
Expand Down
Loading