Modified code to use hashmap for oauth clients

This commit is contained in:
Chris Jean-Marie 2022-03-25 12:10:45 -04:00
parent ade737586d
commit c35a999fd8
1 changed files with 29 additions and 11 deletions

View File

@ -20,12 +20,12 @@ use axum::{
use http::header;
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
PkceCodeChallenge, RedirectUrl, RevocationUrl, Scope, TokenUrl,
PkceCodeChallenge, RedirectUrl, Scope, TokenUrl,
ClientSecret, TokenResponse, CsrfToken,
};
use serde::{Deserialize, Serialize};
use std::{env, net::SocketAddr};
use tower_http::{services::ServeDir, trace::TraceLayer};
use std::{env, net::SocketAddr, collections::HashMap};
use tower_http::{services::ServeDir};
use uuid::Uuid;
const COOKIE_NAME: &str = "SESSION";
@ -51,9 +51,17 @@ async fn main() {
// `MemoryStore` just used as an example. Don't use this in production.
let store = MemoryStore::new();
let google_oauth_client = google_oauth_client();
// Create HashMap to store oauth configurations
let mut oauth_clients = HashMap::<&str, BasicClient>::new();
// Get the client structures
let discord_oauth_client = discord_oauth_client();
// Get oauth clients for the hashmap
//oauth_clients.insert("Google".to_string(), google_oauth_client);
oauth_clients.insert("Discord", discord_oauth_client);
// build our application with a route
let app = Router::new()
// `GET /` goes to `root`
@ -76,8 +84,7 @@ async fn main() {
.route("/auth/callback", get(google_authorized))
.route("/auth/discord", get(discord_authorized))
.layer(Extension(store))
.layer(Extension(discord_oauth_client))
.layer(Extension(google_oauth_client));
.layer(Extension(oauth_clients));
// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
@ -286,11 +293,11 @@ async fn logout(
Redirect::to("/".parse().unwrap())
}
async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
async fn google_auth() -> impl IntoResponse {
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
// Generate the authorization URL to which we'll redirect the user.
let (auth_url, csrf_state) = google_oauth_client
let (auth_url, csrf_state) = google_oauth_client()
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new(
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
@ -305,7 +312,8 @@ async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) ->
Redirect::to(auth_url.to_string().parse().unwrap())
}
async fn discord_auth(Extension(discord_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
async fn discord_auth() -> impl IntoResponse {
let discord_oauth_client = discord_oauth_client();
let (auth_url, _csrf_token) = discord_oauth_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string()))
@ -352,7 +360,7 @@ async fn google_authorized(
*/
// Create a new session filled with user data
let mut session = Session::new();
let session = Session::new();
//session.insert("user", &user_data).unwrap();
// Store session and get corresponding cookie
@ -381,8 +389,13 @@ async fn google_authorized(
async fn discord_authorized(
Query(query): Query<AuthRequest>,
Extension(store): Extension<MemoryStore>,
Extension(discord_oauth_client): Extension<BasicClient>,
Extension(oauth_clients): Extension<HashMap::<&str, BasicClient>>,
) -> impl IntoResponse {
// Check for Discord client
if oauth_clients.contains_key("Discord") {
// Get Discord client
let discord_oauth_client = oauth_clients.get(&"Discord").unwrap();
// Get an auth token
let token = discord_oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
@ -418,6 +431,11 @@ async fn discord_authorized(
headers.insert(SET_COOKIE, cookie.parse().unwrap());
(headers, Redirect::to("/dashboard".parse().unwrap()))
} else {
let mut headers = HeaderMap::new();
(headers, Redirect::to("/".parse().unwrap()))
}
}
struct AuthRedirect;