From c35a999fd829d8bb706dffbf30a0b05c626841f2 Mon Sep 17 00:00:00 2001 From: Chris Jean-Marie Date: Fri, 25 Mar 2022 12:10:45 -0400 Subject: [PATCH] Modified code to use hashmap for oauth clients --- src/main.rs | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index 68da044..6046d5f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,12 +20,12 @@ use axum::{ use http::header; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, - PkceCodeChallenge, RedirectUrl, RevocationUrl, Scope, TokenUrl, + PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, ClientSecret, TokenResponse, CsrfToken, }; use serde::{Deserialize, Serialize}; -use std::{env, net::SocketAddr}; -use tower_http::{services::ServeDir, trace::TraceLayer}; +use std::{env, net::SocketAddr, collections::HashMap}; +use tower_http::{services::ServeDir}; use uuid::Uuid; const COOKIE_NAME: &str = "SESSION"; @@ -51,9 +51,17 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); - let google_oauth_client = google_oauth_client(); + + // Create HashMap to store oauth configurations + let mut oauth_clients = HashMap::<&str, BasicClient>::new(); + + // Get the client structures let discord_oauth_client = discord_oauth_client(); + // Get oauth clients for the hashmap + //oauth_clients.insert("Google".to_string(), google_oauth_client); + oauth_clients.insert("Discord", discord_oauth_client); + // build our application with a route let app = Router::new() // `GET /` goes to `root` @@ -76,8 +84,7 @@ async fn main() { .route("/auth/callback", get(google_authorized)) .route("/auth/discord", get(discord_authorized)) .layer(Extension(store)) - .layer(Extension(discord_oauth_client)) - .layer(Extension(google_oauth_client)); + .layer(Extension(oauth_clients)); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` @@ -286,11 +293,11 @@ async fn logout( Redirect::to("/".parse().unwrap()) } -async fn google_auth(Extension(google_oauth_client): Extension) -> impl IntoResponse { +async fn google_auth() -> impl IntoResponse { let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); // Generate the authorization URL to which we'll redirect the user. - let (auth_url, csrf_state) = google_oauth_client + let (auth_url, csrf_state) = google_oauth_client() .authorize_url(CsrfToken::new_random) .add_scope(Scope::new( "https://www.googleapis.com/auth/userinfo.profile".to_string(), @@ -305,7 +312,8 @@ async fn google_auth(Extension(google_oauth_client): Extension) -> Redirect::to(auth_url.to_string().parse().unwrap()) } -async fn discord_auth(Extension(discord_oauth_client): Extension) -> impl IntoResponse { +async fn discord_auth() -> impl IntoResponse { + let discord_oauth_client = discord_oauth_client(); let (auth_url, _csrf_token) = discord_oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) @@ -352,7 +360,7 @@ async fn google_authorized( */ // Create a new session filled with user data - let mut session = Session::new(); + let session = Session::new(); //session.insert("user", &user_data).unwrap(); // Store session and get corresponding cookie @@ -381,8 +389,13 @@ async fn google_authorized( async fn discord_authorized( Query(query): Query, Extension(store): Extension, - Extension(discord_oauth_client): Extension, + Extension(oauth_clients): Extension>, ) -> impl IntoResponse { + // Check for Discord client + if oauth_clients.contains_key("Discord") { + // Get Discord client + let discord_oauth_client = oauth_clients.get(&"Discord").unwrap(); + // Get an auth token let token = discord_oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) @@ -418,6 +431,11 @@ async fn discord_authorized( headers.insert(SET_COOKIE, cookie.parse().unwrap()); (headers, Redirect::to("/dashboard".parse().unwrap())) + } else { + let mut headers = HeaderMap::new(); + + (headers, Redirect::to("/".parse().unwrap())) + } } struct AuthRedirect;