Skip to content

Commit

Permalink
add new routers for handle oauth requests
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamin-747 committed Jul 27, 2024
1 parent c352e3b commit abba553
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,4 @@ config = "0.14.0"
shadow-rs = "0.30.0"
reqwest = "0.12.5"
lazy_static = "1.5.0"
uuid = "1.10.0"
20 changes: 19 additions & 1 deletion common/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//! consistency, especially when multiple modules need to work with the same
//! set of enum variants.
use std::str::FromStr;

use clap::ValueEnum;

Expand All @@ -17,7 +18,7 @@ pub enum ZtmType {
Relay,
}

impl std::str::FromStr for ZtmType {
impl FromStr for ZtmType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Expand All @@ -28,3 +29,20 @@ impl std::str::FromStr for ZtmType {
}
}
}

/// An enum representing different oauth types.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum SupportOauthType {
GitHub,
}

impl FromStr for SupportOauthType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"github" => Ok(Self::GitHub),
_ => Err(format!("'{}' is not a valid oauth type", s)),
}
}
}
2 changes: 2 additions & 0 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@ tower-http = { workspace = true, features = [
"decompression-full",
] }
tokio = { workspace = true, features = ["net"] }
reqwest = { workspace = true, features = ["json"] }
uuid = { workspace = true, features = ["v4"] }
regex = "1.10.4"
ed25519-dalek = { version = "2.1.1", features = ["pkcs8"] }
1 change: 1 addition & 0 deletions gateway/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use venus::import_repo::repo::Repo;

pub mod api_router;
pub mod mr_router;
pub mod oauth;

#[derive(Clone)]
pub struct ApiServiceState {
Expand Down
72 changes: 72 additions & 0 deletions gateway/src/api/oauth/github.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use axum::async_trait;

use common::errors::MegaError;
use jupiter::context::Context;

use crate::api::oauth::model::{AuthorizeParams, GitHubAccessTokenJson, OauthCallbackParams};
use crate::api::oauth::OauthHandler;

use super::model::GitHubUserJson;

#[derive(Clone)]
pub struct GithubOauthService {
pub context: Context,
pub client_id: String,
pub client_secret: String,
}

const GITHUB_ENDPOINT: &str = "https://github.com";
const GITHUB_API_ENDPOINT: &str = "https://api.github.com";

#[async_trait]
impl OauthHandler for GithubOauthService {
fn authorize_url(&self, params: &AuthorizeParams, state: &str) -> String {
let auth_url = format!(
"https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&state={}",
self.client_id, params.redirect_uri, state
);
auth_url
}

async fn access_token(
&self,
params: OauthCallbackParams,
redirect_uri: &str,
) -> Result<String, MegaError> {
tracing::debug!("{:?}", params);
// get access_token and user for persist
let url = format!(
"{}/login/oauth/access_token?client_id={}&client_secret={}&code={}&redirect_uri={}",
GITHUB_ENDPOINT, self.client_id, self.client_secret, params.code, redirect_uri
);
let client = reqwest::Client::new();
let resp = client
.post(url)
.header("Accept", "application/json")
.send()
.await
.unwrap();
let access_token = resp
.json::<GitHubAccessTokenJson>()
.await
.unwrap()
.access_token;
Ok(access_token)
}

async fn user_info(&self, access_token: &str) -> Result<GitHubUserJson, MegaError> {
let user_url = format!("{}/user", GITHUB_API_ENDPOINT);
let client = reqwest::Client::new();
let resp = client
.get(user_url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Accept", "application/json")
.header("User-Agent", format!("Mega/{}", "0.0.1"))
.send()
.await
.unwrap();
// tracing::debug!("user_resp: {:?}", resp.text().await.unwrap());
let user_info = resp.json::<GitHubUserJson>().await.unwrap();
Ok(user_info)
}
}
126 changes: 126 additions & 0 deletions gateway/src/api/oauth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use std::collections::HashMap;
use std::sync::Arc;

use axum::async_trait;
use axum::response::Redirect;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
routing::get,
Json, Router,
};
use tokio::sync::Mutex;
use uuid::Uuid;

use common::enums::SupportOauthType;
use common::errors::MegaError;
use github::GithubOauthService;
use jupiter::context::Context;
use model::{AuthorizeParams, GitHubUserJson, OauthCallbackParams};

pub mod github;
pub mod model;

#[derive(Clone)]
pub struct OauthServiceState {
pub context: Context,
pub sessions: Arc<Mutex<HashMap<String, String>>>,
}

impl OauthServiceState {
pub fn oauth_handler(&self, ouath_type: SupportOauthType) -> impl OauthHandler {
match ouath_type {
SupportOauthType::GitHub => GithubOauthService {
context: self.context.clone(),
client_id: String::from("Ov23li3y0koaZFzk8CUE"),
client_secret: String::from("58babfc9794ca137feff59c57c82ef6f5318ec37"),
},
}
}
}

#[async_trait]
pub trait OauthHandler: Send + Sync {
fn authorize_url(&self, params: &AuthorizeParams, state: &str) -> String;

async fn access_token(
&self,
params: OauthCallbackParams,
redirect_uri: &str,
) -> Result<String, MegaError>;

async fn user_info(&self, access_token: &str) -> Result<GitHubUserJson, MegaError>;
}

pub fn routers() -> Router<OauthServiceState> {
Router::new()
.route("/:oauth_type/authorize", get(redirect_authorize))
.route("/:oauth_type/callback", get(oauth_callback))
.route("/:oauth_type/user", get(user))
}

async fn redirect_authorize(
Path(oauth_type): Path<String>,
Query(query): Query<AuthorizeParams>,
service_state: State<OauthServiceState>,
) -> Result<Redirect, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};

let mut sessions = service_state.sessions.lock().await;
let state = Uuid::new_v4().to_string();
sessions.insert(state.clone(), query.redirect_uri.clone());
let auth_url = service_state
.oauth_handler(oauth_type)
.authorize_url(&query, &state);
Ok(Redirect::temporary(&auth_url))
}

async fn oauth_callback(
Path(oauth_type): Path<String>,
Query(query): Query<OauthCallbackParams>,
service_state: State<OauthServiceState>,
) -> Result<Redirect, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};
// chcek state,
// TODO storage can be replaced by redis, otherwise invalid state can't be expired
let mut sessions = service_state.sessions.lock().await;

let redirect_uri = match sessions.get(&query.state) {
Some(uri) => uri.clone(),
None => return Err((StatusCode::BAD_REQUEST, "Invalid state".to_string())),
};
let access_token = service_state
.oauth_handler(oauth_type)
.access_token(query.clone(), &redirect_uri)
.await
.unwrap();
sessions.remove(&query.state);

let callback_url = format!("{}/login?access_token={}", redirect_uri, access_token);
Ok(Redirect::temporary(&callback_url))
}

async fn user(
Path(oauth_type): Path<String>,
Query(query): Query<HashMap<String, String>>,
service_state: State<OauthServiceState>,
) -> Result<Json<GitHubUserJson>, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};
let access_token = query.get("access_token").unwrap();

let res = service_state
.oauth_handler(oauth_type)
.user_info(access_token)
.await
.unwrap();
Ok(Json(res))
}
28 changes: 28 additions & 0 deletions gateway/src/api/oauth/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct AuthorizeParams {
pub redirect_uri: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct OauthCallbackParams {
pub code: String,
pub state: String,
}


#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GitHubAccessTokenJson {
pub access_token: String,
pub scope: Option<String>,
pub token_type: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GitHubUserJson {
pub login: String,
pub id: u32,
pub avatar_url: String,
pub email: String,
}
25 changes: 20 additions & 5 deletions gateway/src/https_server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::ops::Deref;
use std::path::PathBuf;
Expand All @@ -14,10 +15,8 @@ use axum::routing::get;
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use clap::Args;
use common::enums::ZtmType;
use gemini::ztm::agent::{run_ztm_client, LocalZTMAgent};
use gemini::ztm::hub::LocalZTMHub;
use regex::Regex;
use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::decompression::RequestDecompressionLayer;
Expand All @@ -26,11 +25,15 @@ use tower_http::trace::TraceLayer;
use ceres::lfs::LfsConfig;
use ceres::protocol::{SmartProtocol, TransportProtocol};
use common::config::Config;
use common::enums::ZtmType;
use common::model::{CommonOptions, GetParams};
use gemini::ztm::agent::{run_ztm_client, LocalZTMAgent};
use gemini::ztm::hub::LocalZTMHub;
use jupiter::context::Context;
use jupiter::raw_storage::local_storage::LocalStorage;

use crate::api::api_router::{self};
use crate::api::oauth::{self, OauthServiceState};
use crate::api::ApiServiceState;
use crate::ca_server::run_ca_server;
use crate::lfs;
Expand Down Expand Up @@ -140,13 +143,25 @@ pub async fn app(config: Config, host: String, port: u16, common: CommonOptions)
common: common.clone(),
};

let api_state = ApiServiceState { context };
let api_state = ApiServiceState {
context: context.clone(),
};

// add RequestDecompressionLayer for handle gzip encode
// add TraceLayer for log record
// add CorsLayer to add cors header
Router::new()
.nest("/api/v1", api_router::routers().with_state(api_state))
.nest(
"/api/v1",
api_router::routers().with_state(api_state.clone()),
)
.nest(
"/auth",
oauth::routers().with_state(OauthServiceState {
context,
sessions: Arc::new(Mutex::new(HashMap::new())),
}),
)
.route(
"/*path",
get(get_method_router)
Expand Down
2 changes: 1 addition & 1 deletion mercury/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ sha1 = { workspace = true }
colored = { workspace = true }
chrono = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { version = "1.7.0", features = ["v4"] }
uuid = { workspace = true, features = ["v4"] }
sha1_smol = "1.0.0"
threadpool = "1.8.1"
num_cpus.workspace = true
Expand Down

0 comments on commit abba553

Please sign in to comment.