Skip to content

update Libra LFS object download for Moly #647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ ring = "0.17.8"
hex = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
async_static = "0.1.3"
once_cell = "1.19.0"
byte-unit = "5.1.4"
scopeguard = "1.2.0"
lru-mem = "0.3.0"
anyhow = { workspace = true }

[target.'cfg(unix)'.dependencies] # only on Unix
pager = "0.16.0"
Expand Down
10 changes: 5 additions & 5 deletions libra/src/command/lfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ceres::lfs::lfs_structs::LockListQuery;
use mercury::internal::index::Index;
use crate::command::{ask_basic_auth, status};
use crate::internal::head::Head;
use crate::internal::protocol::lfs_client::LFS_CLIENT;
use crate::internal::protocol::lfs_client::LFSClient;
use crate::utils::{lfs, path, util};
use crate::utils::path_ext::PathExt;

Expand Down Expand Up @@ -94,7 +94,7 @@ pub async fn execute(cmd: LfsCmds) {
cursor: "".to_string(),
refspec,
};
let locks = LFS_CLIENT.await.get_locks(query).await.locks;
let locks = LFSClient::get().await.get_locks(query).await.locks;
if !locks.is_empty() {
let max_path_len = locks.iter().map(|l| l.path.len()).max().unwrap();
for lock in locks {
Expand All @@ -112,7 +112,7 @@ pub async fn execute(cmd: LfsCmds) {
let refspec = current_refspec().await.unwrap();
let mut auth = None;
loop {
let code = LFS_CLIENT.await.lock(path.clone(), refspec.clone(), auth.clone()).await;
let code = LFSClient::get().await.lock(path.clone(), refspec.clone(), auth.clone()).await;
if code.is_success() {
println!("Locked {}", path);
} else if code == StatusCode::FORBIDDEN {
Expand Down Expand Up @@ -140,7 +140,7 @@ pub async fn execute(cmd: LfsCmds) {
let id = match id {
None => {
// get id by path
let locks = LFS_CLIENT.await.get_locks(LockListQuery {
let locks = LFSClient::get().await.get_locks(LockListQuery {
refspec: refspec.clone(),
path: path.clone(),
id: "".to_string(),
Expand All @@ -157,7 +157,7 @@ pub async fn execute(cmd: LfsCmds) {
};
let mut auth = None;
loop {
let code = LFS_CLIENT.await.unlock(id.clone(), refspec.clone(), force, auth.clone()).await;
let code = LFSClient::get().await.unlock(id.clone(), refspec.clone(), force, auth.clone()).await;
if code.is_success() {
println!("Unlocked {}", path);
} else if code == StatusCode::FORBIDDEN {
Expand Down
6 changes: 4 additions & 2 deletions libra/src/command/restore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use clap::Parser;
use std::collections::{HashMap, HashSet};
use std::{fs, io};
use std::path::PathBuf;
use crate::internal::protocol::lfs_client::LFSClient;
use mercury::hash::SHA1;
use mercury::internal::object::blob::Blob;
use mercury::internal::object::commit::Commit;
use mercury::internal::object::tree::Tree;
use mercury::internal::object::types::ObjectType;
use crate::command::calc_file_blob_hash;
use crate::internal::protocol::lfs_client::LFS_CLIENT;

#[derive(Parser, Debug)]
pub struct RestoreArgs {
Expand Down Expand Up @@ -154,7 +154,9 @@ async fn restore_to_file(hash: &SHA1, path: &PathBuf) -> io::Result<()> {
fs::copy(&lfs_obj_path, &path_abs)?;
} else {
// not exist, download from server
LFS_CLIENT.await.download_object(&oid, size, &path_abs).await;
if let Err(e) = LFSClient::get().await.download_object(&oid, size, &path_abs, None).await {
eprintln!("fatal: {}", e);
}
}
}
None => {
Expand Down
136 changes: 106 additions & 30 deletions libra/src/internal/protocol/lfs_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use crate::internal::config::Config;
use crate::internal::protocol::https_client::BasicAuth;
use crate::internal::protocol::ProtocolClient;
use crate::utils::lfs;
use async_static::async_static;
use ceres::lfs::lfs_structs::{BatchRequest, FetchchunkResponse, Link, LockList, LockListQuery, LockRequest, Ref, Representation, RequestVars, UnlockRequest, VerifiableLockList, VerifiableLockRequest};
use ceres::lfs::lfs_structs::{BatchRequest, ChunkRepresentation, FetchchunkResponse, LockList, LockListQuery, LockRequest, Ref, Representation, RequestVars, UnlockRequest, VerifiableLockList, VerifiableLockRequest};
use futures_util::StreamExt;
use mercury::internal::object::types::ObjectType;
use mercury::internal::pack::entry::Entry;
Expand All @@ -13,18 +12,27 @@ use ring::digest::{Context, SHA256};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::path::Path;
use tokio::io::AsyncWriteExt;
use anyhow::anyhow;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
use tokio::sync::OnceCell;
use url::Url;

async_static! {
pub static ref LFS_CLIENT: LFSClient = LFSClient::new().await;
}

#[derive(Debug)]
pub struct LFSClient {
pub batch_url: Url,
pub lfs_url: Url,
pub client: Client,
}
static LFS_CLIENT: OnceCell<LFSClient> = OnceCell::const_new();
impl LFSClient {
/// Get LFSClient instance
    /// - DO NOT use `async_static!`: it breaks IDE code completion and causes editor lag
pub async fn get() -> &'static LFSClient {
LFS_CLIENT.get_or_init(|| async {
LFSClient::new().await
}).await
}
}

/// see [successful-responses](https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md#successful-responses)
#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -208,8 +216,31 @@ impl LFSClient {
Ok(())
}

/// Just for resume download
async fn update_file_checksum(file: &mut tokio::fs::File, checksum: &mut Context) {
file.seek(tokio::io::SeekFrom::Start(0)).await.unwrap();
let mut buf = [0u8; 8192];
loop {
let n = file.read(&mut buf).await.unwrap();
if n == 0 {
break;
}
checksum.update(&buf[..n]);
}
}

#[allow(clippy::type_complexity)]
/// download (GET) one LFS file from remote server
pub async fn download_object(&self, oid: &str, size: u64, path: impl AsRef<Path>) {
pub async fn download_object(
&self,
oid: &str,
size: u64,
path: impl AsRef<Path>,
mut reporter: Option<(
&mut (dyn FnMut(f64) -> anyhow::Result<()> + Send), // progress callback
f64 // step
)>) -> anyhow::Result<()>
{
let batch_request = BatchRequest {
operation: "download".to_string(),
transfers: vec![lfs::LFS_TRANSFER_API.to_string()],
Expand All @@ -225,68 +256,113 @@ impl LFSClient {
.post(self.batch_url.clone())
.json(&batch_request)
.headers(lfs::LFS_HEADERS.clone());
let response = request.send().await.unwrap();
let response = request.send().await?;

let text = response.text().await.unwrap();
tracing::debug!("LFS download response:\n {:#?}", serde_json::from_str::<serde_json::Value>(&text).unwrap());
let resp = serde_json::from_str::<LfsBatchResponse>(&text).unwrap();
let text = response.text().await?;
tracing::debug!("LFS download response:\n {:#?}", serde_json::from_str::<serde_json::Value>(&text)?);
let resp = serde_json::from_str::<LfsBatchResponse>(&text)?;

let link = resp.objects[0].actions.as_ref().unwrap().get("download").unwrap();

let mut is_chunked = false;
// Chunk API
let links = match self.fetch_chunk_links(&link.href).await {
let mut chunk_size = None; // assumes all chunks have the same size
let links = match self.fetch_chunks(&link.href).await {
Ok(chunks) => {
is_chunked = true;
chunk_size = chunks.first().map(|c| c.size);
tracing::info!("LFS Chunk API supported.");
chunks
chunks.into_iter().map(|c| c.link).collect()
},
Err(_) => vec![link.clone()],
};

let mut file = tokio::fs::File::create(path).await.unwrap();
let mut checksum = Context::new(&SHA256);
let mut got_parts = 0;
let mut file = if links.len() <= 1 || lfs::parse_pointer_file(&path).is_ok() {
// pointer file or Not Chunks, truncate
tokio::fs::File::create(path).await?
} else {
// for Chunks, calc offset to resume download
let mut file = tokio::fs::File::options()
.write(true)
.read(true)
.create(true)
.truncate(false)
.open(&path).await?;
let file_len = file.metadata().await?.len();
if file_len > size {
println!("Local file size is larger than remote, truncate to 0.");
file.set_len(0).await?; // clear
file.seek(tokio::io::SeekFrom::Start(0)).await?;
} else if file_len > 0 {
let chunk_size = chunk_size.unwrap() as u64;
got_parts = file_len / chunk_size;
let file_offset = got_parts * chunk_size;
println!("Resume download from offset: {}, part: {}", file_offset, got_parts + 1);
file.set_len(file_offset).await?; // truncate
Self::update_file_checksum(&mut file, &mut checksum).await; // resume checksum
file.seek(tokio::io::SeekFrom::End(0)).await?;
}
file
};

println!("Downloading LFS file: {}", oid);
let mut cnt = 0;
let total = links.len();
for link in links {
cnt += 1;
let parts = links.len();
let mut downloaded: u64 = file.metadata().await?.len();
let mut last_progress = 0.0;
let start_part = got_parts as usize;
for link in links.iter().skip(start_part) {
got_parts += 1;
if is_chunked {
println!("- part: {}/{}", cnt, total);
println!("- part: {}/{}", got_parts, parts);
}

let mut request = self.client.get(&link.href);
for (k, v) in &link.header {
request = request.header(k, v);
}

let response = request.send().await.unwrap();
let response = request.send().await?;
if !response.status().is_success() {
eprintln!("fatal: LFS download failed. Status: {}, Message: {}", response.status(), response.text().await.unwrap());
return;
eprintln!("fatal: LFS download failed. Status: {}, Message: {}", response.status(), response.text().await?);
return Err(anyhow!("LFS download failed."));
}

let mut stream = response.bytes_stream();

while let Some(chunk) = stream.next().await { // TODO: progress bar
let chunk = chunk.unwrap();
file.write_all(&chunk).await.unwrap();
while let Some(chunk) = stream.next().await { // TODO: progress bar; TODO: multi-threaded or async download
let chunk = chunk?;
file.write_all(&chunk).await?;
checksum.update(&chunk);

// report progress
if let Some((ref mut report_fn, step)) = reporter {
downloaded += chunk.len() as u64;
let progress = (downloaded as f64 / size as f64) * 100.0;
if progress >= last_progress + step {
last_progress = progress;
report_fn(progress)?;
}
}
}
}
let checksum = hex::encode(checksum.finish().as_ref());
if checksum == oid {
println!("Downloaded.");
Ok(())
} else {
eprintln!("fatal: LFS download failed. Checksum mismatch: {} != {}. Fallback to pointer file.", checksum, oid);
let pointer = lfs::format_pointer_string(oid, size);
file.set_len(0).await.unwrap(); // clear
file.write_all(pointer.as_bytes()).await.unwrap();
file.set_len(0).await?; // clear
file.seek(tokio::io::SeekFrom::Start(0)).await?; // ensure
file.write_all(pointer.as_bytes()).await?;
Err(anyhow!("Checksum mismatch, fallback to pointer file."))
}
}

/// Only for MonoRepo (mega)
async fn fetch_chunk_links(&self, obj_link: &str) -> Result<Vec<Link>, ()> {
async fn fetch_chunks(&self, obj_link: &str) -> Result<Vec<ChunkRepresentation>, ()> {
let mut url = Url::parse(obj_link).unwrap();
let path = url.path().trim_end_matches('/');
url.set_path(&(path.to_owned() + "/chunks")); // reserve query params (for GitHub link)
Expand All @@ -304,7 +380,7 @@ impl LFSClient {
let mut res = resp.json::<FetchchunkResponse>().await.unwrap();
// sort by offset
res.chunks.sort_by(|a, b| a.offset.cmp(&b.offset));
Ok(res.chunks.into_iter().map(|c| c.link).collect())
Ok(res.chunks)
}
}

Expand Down
25 changes: 24 additions & 1 deletion libra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use mercury::errors::GitError;

mod command;
mod internal;
pub mod internal;
mod utils;
pub mod cli;

Expand Down Expand Up @@ -34,4 +34,27 @@ mod tests {
std::env::set_current_dir(tmp_dir.path()).unwrap();
exec(vec!["init"]).unwrap();
}

#[tokio::test]
async fn test_lfs_client() {
use url::Url;
use crate::internal::protocol::lfs_client::LFSClient;
use crate::internal::protocol::ProtocolClient;

let client = LFSClient::from_url(&Url::parse("https://git.gitmono.org").unwrap());
println!("{:?}", client);
let mut report_fn = |progress: f64| {
println!("progress: {:.2}%", progress);
Ok(())
};
client.download_object(
"a744b4beab939d899e22c8a070b7041a275582fb942483c9436d455173c7e23d",
338607424,
"/home/bean/projects/tmp/Qwen2.5-0.5B-Instruct-Q2_K.gguf",
Some((
&mut report_fn,
0.1
))
).await.expect("Failed to download object");
}
}
2 changes: 2 additions & 0 deletions mega/tests/lfs_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ fn lfs_split_with_git() {
git_lfs_clone(url).expect("Failed to clone large file from mega server");

mega.kill().expect("Failed to kill mega server");
thread::sleep(Duration::from_secs(1)); // wait for the server to stop, to avoid affecting other tests
}

#[test]
Expand All @@ -241,4 +242,5 @@ fn lfs_split_with_libra() {
libra_lfs_clone(url).expect("(libra)Failed to clone large file from mega server");

mega.kill().expect("Failed to kill mega server");
thread::sleep(Duration::from_secs(1)); // wait for the server to stop, to avoid affecting other tests
}
Loading