Modified code to use hashmap for oauth clients
This commit is contained in:
parent
ade737586d
commit
c35a999fd8
40
src/main.rs
40
src/main.rs
|
|
@ -20,12 +20,12 @@ use axum::{
|
||||||
use http::header;
|
use http::header;
|
||||||
use oauth2::{
|
use oauth2::{
|
||||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||||
PkceCodeChallenge, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
PkceCodeChallenge, RedirectUrl, Scope, TokenUrl,
|
||||||
ClientSecret, TokenResponse, CsrfToken,
|
ClientSecret, TokenResponse, CsrfToken,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{env, net::SocketAddr};
|
use std::{env, net::SocketAddr, collections::HashMap};
|
||||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
use tower_http::{services::ServeDir};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
const COOKIE_NAME: &str = "SESSION";
|
const COOKIE_NAME: &str = "SESSION";
|
||||||
|
|
@ -51,9 +51,17 @@ async fn main() {
|
||||||
|
|
||||||
// `MemoryStore` just used as an example. Don't use this in production.
|
// `MemoryStore` just used as an example. Don't use this in production.
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let google_oauth_client = google_oauth_client();
|
|
||||||
|
// Create HashMap to store oauth configurations
|
||||||
|
let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
||||||
|
|
||||||
|
// Get the client structures
|
||||||
let discord_oauth_client = discord_oauth_client();
|
let discord_oauth_client = discord_oauth_client();
|
||||||
|
|
||||||
|
// Get oauth clients for the hashmap
|
||||||
|
//oauth_clients.insert("Google".to_string(), google_oauth_client);
|
||||||
|
oauth_clients.insert("Discord", discord_oauth_client);
|
||||||
|
|
||||||
// build our application with a route
|
// build our application with a route
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
// `GET /` goes to `root`
|
// `GET /` goes to `root`
|
||||||
|
|
@ -76,8 +84,7 @@ async fn main() {
|
||||||
.route("/auth/callback", get(google_authorized))
|
.route("/auth/callback", get(google_authorized))
|
||||||
.route("/auth/discord", get(discord_authorized))
|
.route("/auth/discord", get(discord_authorized))
|
||||||
.layer(Extension(store))
|
.layer(Extension(store))
|
||||||
.layer(Extension(discord_oauth_client))
|
.layer(Extension(oauth_clients));
|
||||||
.layer(Extension(google_oauth_client));
|
|
||||||
|
|
||||||
// run our app with hyper
|
// run our app with hyper
|
||||||
// `axum::Server` is a re-export of `hyper::Server`
|
// `axum::Server` is a re-export of `hyper::Server`
|
||||||
|
|
@ -286,11 +293,11 @@ async fn logout(
|
||||||
Redirect::to("/".parse().unwrap())
|
Redirect::to("/".parse().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
|
async fn google_auth() -> impl IntoResponse {
|
||||||
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
|
||||||
// Generate the authorization URL to which we'll redirect the user.
|
// Generate the authorization URL to which we'll redirect the user.
|
||||||
let (auth_url, csrf_state) = google_oauth_client
|
let (auth_url, csrf_state) = google_oauth_client()
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
.add_scope(Scope::new(
|
.add_scope(Scope::new(
|
||||||
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
||||||
|
|
@ -305,7 +312,8 @@ async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) ->
|
||||||
Redirect::to(auth_url.to_string().parse().unwrap())
|
Redirect::to(auth_url.to_string().parse().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn discord_auth(Extension(discord_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
|
async fn discord_auth() -> impl IntoResponse {
|
||||||
|
let discord_oauth_client = discord_oauth_client();
|
||||||
let (auth_url, _csrf_token) = discord_oauth_client
|
let (auth_url, _csrf_token) = discord_oauth_client
|
||||||
.authorize_url(CsrfToken::new_random)
|
.authorize_url(CsrfToken::new_random)
|
||||||
.add_scope(Scope::new("identify".to_string()))
|
.add_scope(Scope::new("identify".to_string()))
|
||||||
|
|
@ -352,7 +360,7 @@ async fn google_authorized(
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Create a new session filled with user data
|
// Create a new session filled with user data
|
||||||
let mut session = Session::new();
|
let session = Session::new();
|
||||||
//session.insert("user", &user_data).unwrap();
|
//session.insert("user", &user_data).unwrap();
|
||||||
|
|
||||||
// Store session and get corresponding cookie
|
// Store session and get corresponding cookie
|
||||||
|
|
@ -381,8 +389,13 @@ async fn google_authorized(
|
||||||
async fn discord_authorized(
|
async fn discord_authorized(
|
||||||
Query(query): Query<AuthRequest>,
|
Query(query): Query<AuthRequest>,
|
||||||
Extension(store): Extension<MemoryStore>,
|
Extension(store): Extension<MemoryStore>,
|
||||||
Extension(discord_oauth_client): Extension<BasicClient>,
|
Extension(oauth_clients): Extension<HashMap::<&str, BasicClient>>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
|
// Check for Discord client
|
||||||
|
if oauth_clients.contains_key("Discord") {
|
||||||
|
// Get Discord client
|
||||||
|
let discord_oauth_client = oauth_clients.get(&"Discord").unwrap();
|
||||||
|
|
||||||
// Get an auth token
|
// Get an auth token
|
||||||
let token = discord_oauth_client
|
let token = discord_oauth_client
|
||||||
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||||
|
|
@ -418,6 +431,11 @@ async fn discord_authorized(
|
||||||
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
||||||
|
|
||||||
(headers, Redirect::to("/dashboard".parse().unwrap()))
|
(headers, Redirect::to("/dashboard".parse().unwrap()))
|
||||||
|
} else {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
|
||||||
|
(headers, Redirect::to("/".parse().unwrap()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AuthRedirect;
|
struct AuthRedirect;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue