Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GitHub OAuth login #483

Merged
merged 3 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ go-defer = "0.1.0"
russh = "0.44.0"
russh-keys = "0.44.0"
axum = "0.7.5"
axum-extra = "0.9.3"
tower-http = "0.5.2"
tower = "0.4.13"
hex = "0.4.3"
Expand All @@ -70,3 +71,4 @@ config = "0.14.0"
shadow-rs = "0.30.0"
reqwest = "0.12.5"
lazy_static = "1.5.0"
uuid = "1.10.0"
7 changes: 7 additions & 0 deletions common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Config {
pub pack: PackConfig,
pub ztm: ZTMConfig,
pub lfs: LFSConfig,
pub oauth: OauthConfig,
}

impl Config {
Expand Down Expand Up @@ -257,3 +258,9 @@ impl Default for LFSConfig {
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OauthConfig {
pub github_client_id: String,
pub github_client_secret: String,
}
20 changes: 19 additions & 1 deletion common/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//! consistency, especially when multiple modules need to work with the same
//! set of enum variants.

use std::str::FromStr;

use clap::ValueEnum;

Expand All @@ -17,7 +18,7 @@ pub enum ZtmType {
Relay,
}

impl std::str::FromStr for ZtmType {
impl FromStr for ZtmType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Expand All @@ -28,3 +29,20 @@ impl std::str::FromStr for ZtmType {
}
}
}

/// An enum representing different oauth types.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum SupportOauthType {
GitHub,
}

impl FromStr for SupportOauthType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"github" => Ok(Self::GitHub),
_ => Err(format!("'{}' is not a valid oauth type", s)),
}
}
}
3 changes: 3 additions & 0 deletions gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ tower-http = { workspace = true, features = [
"trace",
"decompression-full",
] }
axum-extra = { workspace = true, features = ["typed-header"]}
tokio = { workspace = true, features = ["net"] }
reqwest = { workspace = true, features = ["json"] }
uuid = { workspace = true, features = ["v4"] }
regex = "1.10.4"
ed25519-dalek = { version = "2.1.1", features = ["pkcs8"] }
1 change: 1 addition & 0 deletions gateway/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use venus::import_repo::repo::Repo;

pub mod api_router;
pub mod mr_router;
pub mod oauth;

#[derive(Clone)]
pub struct ApiServiceState {
Expand Down
72 changes: 72 additions & 0 deletions gateway/src/api/oauth/github.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use axum::async_trait;

use common::errors::MegaError;
use jupiter::context::Context;

use crate::api::oauth::model::{AuthorizeParams, GitHubAccessTokenJson, OauthCallbackParams};
use crate::api::oauth::OauthHandler;

use super::model::GitHubUserJson;

#[derive(Clone)]
pub struct GithubOauthService {
pub context: Context,
pub client_id: String,
pub client_secret: String,
}

const GITHUB_ENDPOINT: &str = "https://github.com";
const GITHUB_API_ENDPOINT: &str = "https://api.github.com";

#[async_trait]
impl OauthHandler for GithubOauthService {
fn authorize_url(&self, params: &AuthorizeParams, state: &str) -> String {
let auth_url = format!(
"https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&state={}",
self.client_id, params.redirect_uri, state
);
auth_url
}

async fn access_token(
&self,
params: OauthCallbackParams,
redirect_uri: &str,
) -> Result<String, MegaError> {
tracing::debug!("{:?}", params);
// get access_token and user for persist
let url = format!(
"{}/login/oauth/access_token?client_id={}&client_secret={}&code={}&redirect_uri={}",
GITHUB_ENDPOINT, self.client_id, self.client_secret, params.code, redirect_uri
);
let client = reqwest::Client::new();
let resp = client
.post(url)
.header("Accept", "application/json")
.send()
.await
.unwrap();
let access_token = resp
.json::<GitHubAccessTokenJson>()
.await
.unwrap()
.access_token;
Ok(access_token)
}

async fn user_info(&self, access_token: &str) -> Result<GitHubUserJson, MegaError> {
let user_url = format!("{}/user", GITHUB_API_ENDPOINT);
let client = reqwest::Client::new();
let resp = client
.get(user_url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Accept", "application/json")
.header("User-Agent", format!("Mega/{}", "0.0.1"))
.send()
.await
.unwrap();
// tracing::debug!("user_resp: {:?}", resp.text().await.unwrap());
let user_info = resp.json::<GitHubUserJson>().await.unwrap();
Ok(user_info)
}
}
127 changes: 127 additions & 0 deletions gateway/src/api/oauth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::collections::HashMap;
use std::sync::Arc;

use axum::async_trait;
use axum::response::Redirect;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
routing::get,
Json, Router,
};
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use tokio::sync::Mutex;
use uuid::Uuid;

use common::enums::SupportOauthType;
use common::errors::MegaError;
use github::GithubOauthService;
use jupiter::context::Context;
use model::{AuthorizeParams, GitHubUserJson, OauthCallbackParams};

pub mod github;
pub mod model;

#[derive(Clone)]
pub struct OauthServiceState {
pub context: Context,
pub sessions: Arc<Mutex<HashMap<String, String>>>,
}

impl OauthServiceState {
pub fn oauth_handler(&self, ouath_type: SupportOauthType) -> impl OauthHandler {
match ouath_type {
SupportOauthType::GitHub => GithubOauthService {
context: self.context.clone(),
client_id: self.context.config.oauth.github_client_id.clone(),
client_secret: self.context.config.oauth.github_client_secret.clone(),
},
}
}
}

#[async_trait]
pub trait OauthHandler: Send + Sync {
fn authorize_url(&self, params: &AuthorizeParams, state: &str) -> String;

async fn access_token(
&self,
params: OauthCallbackParams,
redirect_uri: &str,
) -> Result<String, MegaError>;

async fn user_info(&self, access_token: &str) -> Result<GitHubUserJson, MegaError>;
}

pub fn routers() -> Router<OauthServiceState> {
Router::new()
.route("/:oauth_type/authorize", get(redirect_authorize))
.route("/:oauth_type/callback", get(oauth_callback))
.route("/:oauth_type/user", get(user))
}

async fn redirect_authorize(
Path(oauth_type): Path<String>,
Query(query): Query<AuthorizeParams>,
service_state: State<OauthServiceState>,
) -> Result<Redirect, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};

let mut sessions = service_state.sessions.lock().await;
let state = Uuid::new_v4().to_string();
sessions.insert(state.clone(), query.redirect_uri.clone());
let auth_url = service_state
.oauth_handler(oauth_type)
.authorize_url(&query, &state);
Ok(Redirect::temporary(&auth_url))
}

async fn oauth_callback(
Path(oauth_type): Path<String>,
Query(query): Query<OauthCallbackParams>,
service_state: State<OauthServiceState>,
) -> Result<Redirect, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};
// chcek state,
// TODO storage can be replaced by redis, otherwise invalid state can't be expired
let mut sessions = service_state.sessions.lock().await;

let redirect_uri = match sessions.get(&query.state) {
Some(uri) => uri.clone(),
None => return Err((StatusCode::BAD_REQUEST, "Invalid state".to_string())),
};
let access_token = service_state
.oauth_handler(oauth_type)
.access_token(query.clone(), &redirect_uri)
.await
.unwrap();
sessions.remove(&query.state);

let callback_url = format!("{}?access_token={}", redirect_uri, access_token);
Ok(Redirect::temporary(&callback_url))
}

async fn user(
Path(oauth_type): Path<String>,
TypedHeader(Authorization::<Bearer>(token)): TypedHeader<Authorization<Bearer>>,
service_state: State<OauthServiceState>,
) -> Result<Json<GitHubUserJson>, (StatusCode, String)> {
let oauth_type: SupportOauthType = match oauth_type.parse::<SupportOauthType>() {
Ok(value) => value,
Err(err) => return Err((StatusCode::BAD_REQUEST, err)),
};
let res = service_state
.oauth_handler(oauth_type)
.user_info(token.token())
.await
.unwrap();
Ok(Json(res))
}
28 changes: 28 additions & 0 deletions gateway/src/api/oauth/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
pub struct AuthorizeParams {
pub redirect_uri: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct OauthCallbackParams {
pub code: String,
pub state: String,
}


#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GitHubAccessTokenJson {
pub access_token: String,
pub scope: Option<String>,
pub token_type: String,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GitHubUserJson {
pub login: String,
pub id: u32,
pub avatar_url: String,
pub email: String,
}
34 changes: 27 additions & 7 deletions gateway/src/https_server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::ops::Deref;
use std::path::PathBuf;
Expand All @@ -8,16 +9,14 @@ use std::{thread, time};
use anyhow::Result;
use axum::body::Body;
use axum::extract::{Query, State};
use axum::http::{Request, StatusCode, Uri};
use axum::http::{self, Request, StatusCode, Uri};
use axum::response::Response;
use axum::routing::get;
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use clap::Args;
use common::enums::ZtmType;
use gemini::ztm::agent::{run_ztm_client, LocalZTMAgent};
use gemini::ztm::hub::LocalZTMHub;
use regex::Regex;
use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::decompression::RequestDecompressionLayer;
Expand All @@ -26,11 +25,15 @@ use tower_http::trace::TraceLayer;
use ceres::lfs::LfsConfig;
use ceres::protocol::{SmartProtocol, TransportProtocol};
use common::config::Config;
use common::enums::ZtmType;
use common::model::{CommonOptions, GetParams};
use gemini::ztm::agent::{run_ztm_client, LocalZTMAgent};
use gemini::ztm::hub::LocalZTMHub;
use jupiter::context::Context;
use jupiter::raw_storage::local_storage::LocalStorage;

use crate::api::api_router::{self};
use crate::api::oauth::{self, OauthServiceState};
use crate::api::ApiServiceState;
use crate::ca_server::run_ca_server;
use crate::lfs;
Expand Down Expand Up @@ -140,20 +143,37 @@ pub async fn app(config: Config, host: String, port: u16, common: CommonOptions)
common: common.clone(),
};

let api_state = ApiServiceState { context };
let api_state = ApiServiceState {
context: context.clone(),
};

// add RequestDecompressionLayer for handle gzip encode
// add TraceLayer for log record
// add CorsLayer to add cors header
Router::new()
.nest("/api/v1", api_router::routers().with_state(api_state))
.nest(
"/api/v1",
api_router::routers().with_state(api_state.clone()),
)
.nest(
"/auth",
oauth::routers().with_state(OauthServiceState {
context,
sessions: Arc::new(Mutex::new(HashMap::new())),
}),
)
.route(
"/*path",
get(get_method_router)
.post(post_method_router)
.put(put_method_router),
)
.layer(ServiceBuilder::new().layer(CorsLayer::new().allow_origin(Any)))
.layer(
ServiceBuilder::new().layer(CorsLayer::new().allow_origin(Any).allow_headers(vec![
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
])),
)
.layer(TraceLayer::new_for_http())
.layer(RequestDecompressionLayer::new())
.with_state(state)
Expand Down
Loading
Loading