From fd2427b337e47290e6f44de4576f6a3b602d0d83 Mon Sep 17 00:00:00 2001 From: Hu Yueh-Wei Date: Fri, 25 Oct 2024 13:27:57 +0700 Subject: [PATCH] feat: add resume-download in tman (#203) --- core/src/ten_manager/src/registry/remote.rs | 113 +++++++++++++++++--- 1 file changed, 98 insertions(+), 15 deletions(-) diff --git a/core/src/ten_manager/src/registry/remote.rs b/core/src/ten_manager/src/registry/remote.rs index fe272530a..860c48be4 100644 --- a/core/src/ten_manager/src/registry/remote.rs +++ b/core/src/ten_manager/src/registry/remote.rs @@ -306,62 +306,145 @@ pub async fn upload_package( Ok(()) } +// Parse the header of content-range. +fn parse_content_range(content_range: &str) -> Option<(u64, u64, u64)> { + let parts: Vec<&str> = content_range.split('/').collect(); + if parts.len() == 2 { + let total_size = parts[1].parse().ok()?; + let range_parts: Vec<&str> = parts[0].split('-').collect(); + if range_parts.len() == 2 { + let start = range_parts[0].parse().ok()?; + let end = range_parts[1].parse().ok()?; + return Some((start, end, total_size)); + } + } + None +} + pub async fn get_package<'a>( tman_config: &TmanConfig, url: &str, - temp_path: &'a mut NamedTempFile, + temp_file: &'a mut NamedTempFile, ) -> Result<()> { let client = reqwest::Client::new(); // Wrap the temp_file in an Rc> to allow mutable borrowing inside // the async closure. - let temp_path = Rc::new(RefCell::new(temp_path)); + let temp_file = Rc::new(RefCell::new(temp_file)); let max_retries = 3; let retry_delay = Duration::from_millis(100); + let download_complete = Rc::new(RefCell::new(false)); + // Pass the Rc> into the retry logic retry_async(tman_config, max_retries, retry_delay, || { let client = client.clone(); let url = url.to_string(); - let temp_path = Rc::clone(&temp_path); // Clone the Rc pointer. + let temp_file = Rc::clone(&temp_file); // Clone the Rc pointer. + let download_complete = Rc::clone(&download_complete); Box::pin(async move { - let response = - client.get(&url).send().await.with_context(|| { + // Check the size of the file that has already been downloaded. + let temp_file_len = { + let temp_file_borrow = temp_file.borrow(); + temp_file_borrow + .as_file() + .metadata() + .map(|metadata| metadata.len()) + .unwrap_or(0) + }; + + // Set the Range header to support resumable downloads. + let mut headers = HeaderMap::new(); + if temp_file_len > 0 { + headers.insert( + reqwest::header::RANGE, + format!("bytes={}-", temp_file_len).parse().unwrap(), + ); + } + + let response = client + .get(&url) + .headers(headers) + .send() + .await + .with_context(|| { format!("Failed to send GET request to {}", url) })?; - if !response.status().is_success() { + if !response.status().is_success() + && response.status() != reqwest::StatusCode::PARTIAL_CONTENT + { return Err(anyhow!( "Failed to download the package: HTTP {}", response.status() )); } + // Get the headers of Content-Range or Content-Length. + let content_range = response + .headers() + .get(reqwest::header::CONTENT_RANGE) + .cloned(); + let content_length = response + .headers() + .get(reqwest::header::CONTENT_LENGTH) + .cloned(); + + // Read the content of the response and append the newly downloaded + // part to the file. let content = response .bytes() .await .with_context(|| "Failed to read bytes from response")?; + if content.is_empty() { + return Err(anyhow!("No new content downloaded")); + } + // Mutably borrow temp_file inside the async block. - let mut temp_path_borrow = temp_path.borrow_mut(); + let mut temp_file_borrow = temp_file.borrow_mut(); - temp_path_borrow + temp_file_borrow + .as_file_mut() .write_all(&content) .with_context(|| "Failed to write content to temporary file")?; - tman_verbose_println!( - tman_config, - "Package downloaded successfully from {} and written to {}", - url, - temp_path_borrow.path().display() - ); + // Check if we have downloaded the entire file. + if let Some(content_range) = content_range { + // Parse the content-range to check if download is complete. + let content_range_str = content_range.to_str().unwrap(); + if let Some((_, _, total_size)) = + parse_content_range(content_range_str) + { + if temp_file_len + content.len() as u64 >= total_size { + *download_complete.borrow_mut() = true; + } + } + } else if content_length.is_some() { + // If there is no content-range but content-length exists, the + // download should be complete in one go. + *download_complete.borrow_mut() = true; + } Ok(()) }) }) - .await + .await?; + + // Only print when `download_complete` is `true`. + if *download_complete.borrow() { + let temp_file_borrow = temp_file.borrow(); + tman_verbose_println!( + tman_config, + "Package downloaded successfully from {} and written to {}", + url, + temp_file_borrow.path().display() + ); + } + + Ok(()) } #[derive(Debug, Serialize, Deserialize)]