Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance basic authentication for Libra #701

Merged
merged 10 commits into from
Nov 25, 2024
22 changes: 8 additions & 14 deletions libra/src/command/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use ceres::protocol::ServiceType::UploadPack;
use clap::Parser;
use indicatif::ProgressBar;
use mercury::internal::object::commit::Commit;
use mercury::{errors::GitError, hash::SHA1};
use mercury::hash::SHA1;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::io::StreamReader;
use url::Url;

use crate::command::{ask_basic_auth, load_object};
use crate::command::load_object;
use crate::{
command::index_pack::{self, IndexPackArgs},
internal::{
Expand Down Expand Up @@ -89,20 +89,14 @@ pub async fn fetch_repository(remote_config: &RemoteConfig, branch: Option<Strin
};
let http_client = HttpsClient::from_url(&url);

let mut refs = http_client.discovery_reference(UploadPack, None).await;
let mut auth = None;
while let Err(e) = refs {
if let GitError::UnAuthorized(_) = e {
auth = Some(ask_basic_auth());
refs = http_client
.discovery_reference(UploadPack, auth.clone())
.await;
} else {
let refs = match http_client.discovery_reference(UploadPack).await {
Ok(refs) => refs,
Err(e) => {
eprintln!("fatal: {}", e);
return;
}
}
let refs = refs.unwrap();
};

if refs.is_empty() {
tracing::warn!("fetch empty, no refs found");
return;
Expand Down Expand Up @@ -133,7 +127,7 @@ pub async fn fetch_repository(remote_config: &RemoteConfig, branch: Option<Strin
let have = current_have().await; // TODO: return `DiscRef` rather than only hash, to compare `have` & `want` more accurately

let mut result_stream = http_client
.fetch_objects(&have, &want, auth.to_owned())
.fetch_objects(&have, &want)
.await
.unwrap();

Expand Down
40 changes: 14 additions & 26 deletions libra/src/command/lfs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::path::Path;
use reqwest::StatusCode;
use ceres::lfs::lfs_structs::LockListQuery;
use mercury::internal::index::Index;
use crate::command::{ask_basic_auth, status};
use crate::command::status;
use crate::internal::head::Head;
use crate::internal::protocol::lfs_client::LFSClient;
use crate::utils::{lfs, path, util};
Expand Down Expand Up @@ -80,7 +80,7 @@ pub async fn execute(cmd: LfsCmds) {
}
}
}
LfsCmds::Untrack { path } => {
LfsCmds::Untrack { path } => { // only remove totally same pattern with path ?
let path = convert_patterns_to_workdir(path); //
untrack_lfs_patterns(&attr_path, path).unwrap();
}
Expand Down Expand Up @@ -110,19 +110,13 @@ pub async fn execute(cmd: LfsCmds) {
}

let refspec = current_refspec().await.unwrap();
let mut auth = None;
loop {
let code = LFSClient::get().await.lock(path.clone(), refspec.clone(), auth.clone()).await;
if code.is_success() {
println!("Locked {}", path);
} else if code == StatusCode::FORBIDDEN {
eprintln!("Forbidden: You must have push access to create a lock");
auth = Some(ask_basic_auth());
continue;
} else if code == StatusCode::CONFLICT {
eprintln!("Conflict: already created lock");
}
break;
let code = LFSClient::get().await.lock(path.clone(), refspec.clone()).await;
if code.is_success() {
println!("Locked {}", path);
} else if code == StatusCode::FORBIDDEN {
eprintln!("Forbidden: You must have push access to create a lock");
} else if code == StatusCode::CONFLICT {
eprintln!("Conflict: already created lock");
}
}
LfsCmds::Unlock { path, force, id } => {
Expand Down Expand Up @@ -155,17 +149,11 @@ pub async fn execute(cmd: LfsCmds) {
}
Some(id) => id
};
let mut auth = None;
loop {
let code = LFSClient::get().await.unlock(id.clone(), refspec.clone(), force, auth.clone()).await;
if code.is_success() {
println!("Unlocked {}", path);
} else if code == StatusCode::FORBIDDEN {
eprintln!("Forbidden: You must have push access to unlock");
auth = Some(ask_basic_auth());
continue;
}
break;
let code = LFSClient::get().await.unlock(id.clone(), refspec.clone(), force).await;
if code.is_success() {
println!("Unlocked {}", path);
} else if code == StatusCode::FORBIDDEN {
eprintln!("Forbidden: You must have push access to unlock");
}
}
LfsCmds::LsFiles { long, size, name_only} => {
Expand Down
22 changes: 8 additions & 14 deletions libra/src/command/push.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ use tokio::sync::mpsc;
use url::Url;
use ceres::protocol::ServiceType::ReceivePack;
use ceres::protocol::smart::{add_pkt_line_string, read_pkt_line};
use mercury::errors::GitError;
use mercury::hash::SHA1;
use mercury::internal::object::blob::Blob;
use mercury::internal::object::commit::Commit;
use mercury::internal::object::tree::{Tree, TreeItemMode};
use mercury::internal::pack::encode::PackEncoder;
use mercury::internal::pack::entry::Entry;
use crate::command::{ask_basic_auth, branch};
use crate::command::branch;
use crate::internal::branch::Branch;
use crate::internal::config::Config;
use crate::internal::head::Head;
use crate::internal::protocol::https_client::{BasicAuth, HttpsClient};
use crate::internal::protocol::https_client::HttpsClient;
use crate::internal::protocol::lfs_client::LFSClient;
use crate::internal::protocol::ProtocolClient;
use crate::utils::object_ext::{BlobExt, CommitExt, TreeExt};
Expand Down Expand Up @@ -74,18 +73,13 @@ pub async fn execute(args: PushArgs) {

let url = Url::parse(&repo_url).unwrap();
let client = HttpsClient::from_url(&url);
let mut refs = client.discovery_reference(ReceivePack, None).await;
let mut auth: Option<BasicAuth> = None;
while let Err(e) = refs { // retry if unauthorized
if let GitError::UnAuthorized(_) = e {
auth = Some(ask_basic_auth());
refs = client.discovery_reference(ReceivePack, auth.clone()).await;
} else {
let refs = match client.discovery_reference(ReceivePack).await {
Ok(refs) => refs,
Err(e) => {
eprintln!("fatal: {}", e);
return;
}
}
let refs = refs.unwrap();
};

let tracked_branch = Config::get("branch", Some(&branch), "merge")
.await // New branch may not have tracking branch
Expand Down Expand Up @@ -115,7 +109,7 @@ pub async fn execute(args: PushArgs) {

{ // upload lfs files
let client = LFSClient::from_url(&url);
let res = client.push_objects(&objs, auth.clone()).await;
let res = client.push_objects(&objs).await;
if res.is_err() {
eprintln!("fatal: LFS files upload failed, stop pushing");
return;
Expand Down Expand Up @@ -143,7 +137,7 @@ pub async fn execute(args: PushArgs) {
data.extend_from_slice(&pack_data);
println!("Delta compression done.");

let res = client.send_pack(data.freeze(), auth).await.unwrap(); // TODO: send stream
let res = client.send_pack(data.freeze()).await.unwrap(); // TODO: send stream

if res.status() != 200 {
eprintln!("status code: {}", res.status());
Expand Down
91 changes: 58 additions & 33 deletions libra/src/internal/protocol/https_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ use futures_util::{StreamExt, TryStreamExt};
use mercury::errors::GitError;
use mercury::hash::SHA1;
use reqwest::header::CONTENT_TYPE;
use reqwest::{Body, Response};
use reqwest::{Body, RequestBuilder, Response, StatusCode};
use std::io::Error as IoError;
use std::ops::Deref;
use std::sync::Mutex;
use tokio_util::bytes::BytesMut;
use url::Url;
use crate::command::ask_basic_auth;

/// A Git protocol client that communicates with a Git server over HTTPS.
/// Only support `SmartProtocol` now, see [http-protocol](https://www.git-scm.com/docs/http-protocol) for protocol details.
Expand Down Expand Up @@ -41,6 +44,41 @@ pub struct BasicAuth {
pub(crate) password: String,
}

impl BasicAuth {
/// send request with basic auth, retry 3 times
pub async fn send<Fut>(request_builder: impl Fn() -> Fut) -> Result<Response, reqwest::Error>
where
Fut: std::future::Future<Output=RequestBuilder>,
{
static AUTH: Mutex<Option<BasicAuth>> = Mutex::new(None);
const MAX_TRY: usize = 3;
let mut res;
let mut try_cnt = 0;
loop {
let mut request = request_builder().await; // RequestBuilder can't be cloned
if let Some(auth) = AUTH.lock().unwrap().deref() {
request = request.basic_auth(auth.username.clone(), Some(auth.password.clone()));
} // if no auth exists, try without auth (e.g. clone public)
res = request.send().await?;
if res.status() == StatusCode::FORBIDDEN { // 403: no access, no need to retry
eprintln!("Authentication failed, forbidden");
break;
} else if res.status() != StatusCode::UNAUTHORIZED {
break;
}
// 401 (Unauthorized): username or password is incorrect
if try_cnt >= MAX_TRY {
eprintln!("Failed to authenticate after {} attempts", MAX_TRY);
break;
}
eprintln!("Authentication required, retrying...");
AUTH.lock().unwrap().replace(ask_basic_auth());
try_cnt += 1;
}
Ok(res)
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct DiscoveredReference {
pub(crate) _hash: String,
Expand All @@ -61,18 +99,13 @@ impl HttpsClient {
pub async fn discovery_reference(
&self,
service: ServiceType,
auth: Option<BasicAuth>,
) -> Result<Vec<DiscRef>, GitError> {
let service: &str = &service.to_string();
let url = self
.url
.join(&format!("info/refs?service={}", service))
.unwrap();
let mut request = self.client.get(url);
if let Some(auth) = auth {
request = request.basic_auth(auth.username, Some(auth.password));
}
let res = request.send().await.unwrap();
let res = BasicAuth::send(|| async{self.client.get(url.clone())}).await.unwrap();
tracing::debug!("{:?}", res);

if res.status() == 401 {
Expand Down Expand Up @@ -165,22 +198,19 @@ impl HttpsClient {
&self,
have: &Vec<String>,
want: &Vec<String>,
auth: Option<BasicAuth>,
) -> Result<impl StreamExt<Item = Result<Bytes, IoError>>, IoError> {
// POST $GIT_URL/git-upload-pack HTTP/1.0
let url = self.url.join("git-upload-pack").unwrap();
let body = generate_upload_pack_content(have, want).await;
tracing::debug!("fetch_objects with body: {:?}", body);

let mut req = self
.client
.post(url)
.header("Content-Type", "application/x-git-upload-pack-request")
.body(body);
if let Some(auth) = auth {
req = req.basic_auth(auth.username, Some(auth.password));
}
let res = req.send().await.unwrap();
let res = BasicAuth::send(|| async {
self
.client
.post(url.clone())
.header("Content-Type", "application/x-git-upload-pack-request")
.body(body.clone())
}).await.unwrap();
tracing::debug!("request: {:?}", res);

if res.status() != 200 && res.status() != 304 {
Expand All @@ -197,22 +227,17 @@ impl HttpsClient {
Ok(result)
}

pub async fn send_pack<T: Into<Body>>(
pub async fn send_pack<T: Into<Body> + Clone>(
&self,
data: T,
auth: Option<BasicAuth>,
) -> Result<Response, reqwest::Error> {
let mut request = self
.client
.post(self.url.join("git-receive-pack").unwrap())
.header(CONTENT_TYPE, "application/x-git-receive-pack-request")
.body(data);

if let Some(auth) = auth {
request = request.basic_auth(auth.username, Some(auth.password));
}

request.send().await
BasicAuth::send(|| async {
self
.client
.post(self.url.join("git-receive-pack").unwrap())
.header(CONTENT_TYPE, "application/x-git-receive-pack-request")
.body(data.clone())
}).await
}
}
/// for fetching
Expand Down Expand Up @@ -259,7 +284,7 @@ mod tests {
let test_repo = "https://github.com/web3infra-foundation/mega.git/";

let client = HttpsClient::from_url(&Url::parse(test_repo).unwrap());
let refs = client.discovery_reference(UploadPack, None).await;
let refs = client.discovery_reference(UploadPack).await;
if refs.is_err() {
tracing::error!("{:?}", refs.err().unwrap());
panic!();
Expand All @@ -276,7 +301,7 @@ mod tests {

let test_repo = "https://github.com/web3infra-foundation/mega/";
let client = HttpsClient::from_url(&Url::parse(test_repo).unwrap());
let refs = client.discovery_reference(UploadPack, None).await.unwrap();
let refs = client.discovery_reference(UploadPack).await.unwrap();
let refs: Vec<DiscoveredReference> = refs
.iter()
.filter(|r| r._ref.starts_with("refs/heads"))
Expand All @@ -287,7 +312,7 @@ mod tests {
let want = refs.iter().map(|r| r._hash.clone()).collect();

let have = vec!["81a162e7b725bbad2adfe01879fd57e0119406b9".to_string()];
let mut result_stream = client.fetch_objects(&have, &want, None).await.unwrap();
let mut result_stream = client.fetch_objects(&have, &want).await.unwrap();

let mut buffer = vec![];
while let Some(item) = result_stream.next().await {
Expand Down
Loading
Loading