Skip to content

Commit

Permalink
feat: add resume-download in tman (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
halajohn authored Oct 25, 2024
1 parent 268302a commit fd2427b
Showing 1 changed file with 98 additions and 15 deletions.
113 changes: 98 additions & 15 deletions core/src/ten_manager/src/registry/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,62 +306,145 @@ pub async fn upload_package(
Ok(())
}

// Parse the header of content-range.
fn parse_content_range(content_range: &str) -> Option<(u64, u64, u64)> {
let parts: Vec<&str> = content_range.split('/').collect();
if parts.len() == 2 {
let total_size = parts[1].parse().ok()?;
let range_parts: Vec<&str> = parts[0].split('-').collect();
if range_parts.len() == 2 {
let start = range_parts[0].parse().ok()?;
let end = range_parts[1].parse().ok()?;
return Some((start, end, total_size));
}
}
None
}

pub async fn get_package<'a>(
tman_config: &TmanConfig,
url: &str,
temp_path: &'a mut NamedTempFile,
temp_file: &'a mut NamedTempFile,
) -> Result<()> {
let client = reqwest::Client::new();

// Wrap the temp_file in an Rc<RefCell<_>> to allow mutable borrowing inside
// the async closure.
let temp_path = Rc::new(RefCell::new(temp_path));
let temp_file = Rc::new(RefCell::new(temp_file));

let max_retries = 3;
let retry_delay = Duration::from_millis(100);

let download_complete = Rc::new(RefCell::new(false));

// Pass the Rc<RefCell<>> into the retry logic
retry_async(tman_config, max_retries, retry_delay, || {
let client = client.clone();
let url = url.to_string();
let temp_path = Rc::clone(&temp_path); // Clone the Rc pointer.
let temp_file = Rc::clone(&temp_file); // Clone the Rc pointer.
let download_complete = Rc::clone(&download_complete);

Box::pin(async move {
let response =
client.get(&url).send().await.with_context(|| {
// Check the size of the file that has already been downloaded.
let temp_file_len = {
let temp_file_borrow = temp_file.borrow();
temp_file_borrow
.as_file()
.metadata()
.map(|metadata| metadata.len())
.unwrap_or(0)
};

// Set the Range header to support resumable downloads.
let mut headers = HeaderMap::new();
if temp_file_len > 0 {
headers.insert(
reqwest::header::RANGE,
format!("bytes={}-", temp_file_len).parse().unwrap(),
);
}

let response = client
.get(&url)
.headers(headers)
.send()
.await
.with_context(|| {
format!("Failed to send GET request to {}", url)
})?;

if !response.status().is_success() {
if !response.status().is_success()
&& response.status() != reqwest::StatusCode::PARTIAL_CONTENT
{
return Err(anyhow!(
"Failed to download the package: HTTP {}",
response.status()
));
}

// Get the headers of Content-Range or Content-Length.
let content_range = response
.headers()
.get(reqwest::header::CONTENT_RANGE)
.cloned();
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.cloned();

// Read the content of the response and append the newly downloaded
// part to the file.
let content = response
.bytes()
.await
.with_context(|| "Failed to read bytes from response")?;

if content.is_empty() {
return Err(anyhow!("No new content downloaded"));
}

// Mutably borrow temp_file inside the async block.
let mut temp_path_borrow = temp_path.borrow_mut();
let mut temp_file_borrow = temp_file.borrow_mut();

temp_path_borrow
temp_file_borrow
.as_file_mut()
.write_all(&content)
.with_context(|| "Failed to write content to temporary file")?;

tman_verbose_println!(
tman_config,
"Package downloaded successfully from {} and written to {}",
url,
temp_path_borrow.path().display()
);
// Check if we have downloaded the entire file.
if let Some(content_range) = content_range {
// Parse the content-range to check if download is complete.
let content_range_str = content_range.to_str().unwrap();
if let Some((_, _, total_size)) =
parse_content_range(content_range_str)
{
if temp_file_len + content.len() as u64 >= total_size {
*download_complete.borrow_mut() = true;
}
}
} else if content_length.is_some() {
// If there is no content-range but content-length exists, the
// download should be complete in one go.
*download_complete.borrow_mut() = true;
}

Ok(())
})
})
.await
.await?;

// Only print when `download_complete` is `true`.
if *download_complete.borrow() {
let temp_file_borrow = temp_file.borrow();
tman_verbose_println!(
tman_config,
"Package downloaded successfully from {} and written to {}",
url,
temp_file_borrow.path().display()
);
}

Ok(())
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down

0 comments on commit fd2427b

Please sign in to comment.