Modified code to use hashmap for oauth clients
This commit is contained in:
parent
ade737586d
commit
c35a999fd8
40
src/main.rs
40
src/main.rs
|
|
@ -20,12 +20,12 @@ use axum::{
|
|||
use http::header;
|
||||
use oauth2::{
|
||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||
PkceCodeChallenge, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||
PkceCodeChallenge, RedirectUrl, Scope, TokenUrl,
|
||||
ClientSecret, TokenResponse, CsrfToken,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, net::SocketAddr};
|
||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||
use std::{env, net::SocketAddr, collections::HashMap};
|
||||
use tower_http::{services::ServeDir};
|
||||
use uuid::Uuid;
|
||||
|
||||
const COOKIE_NAME: &str = "SESSION";
|
||||
|
|
@ -51,9 +51,17 @@ async fn main() {
|
|||
|
||||
// `MemoryStore` just used as an example. Don't use this in production.
|
||||
let store = MemoryStore::new();
|
||||
let google_oauth_client = google_oauth_client();
|
||||
|
||||
// Create HashMap to store oauth configurations
|
||||
let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
||||
|
||||
// Get the client structures
|
||||
let discord_oauth_client = discord_oauth_client();
|
||||
|
||||
// Get oauth clients for the hashmap
|
||||
//oauth_clients.insert("Google".to_string(), google_oauth_client);
|
||||
oauth_clients.insert("Discord", discord_oauth_client);
|
||||
|
||||
// build our application with a route
|
||||
let app = Router::new()
|
||||
// `GET /` goes to `root`
|
||||
|
|
@ -76,8 +84,7 @@ async fn main() {
|
|||
.route("/auth/callback", get(google_authorized))
|
||||
.route("/auth/discord", get(discord_authorized))
|
||||
.layer(Extension(store))
|
||||
.layer(Extension(discord_oauth_client))
|
||||
.layer(Extension(google_oauth_client));
|
||||
.layer(Extension(oauth_clients));
|
||||
|
||||
// run our app with hyper
|
||||
// `axum::Server` is a re-export of `hyper::Server`
|
||||
|
|
@ -286,11 +293,11 @@ async fn logout(
|
|||
Redirect::to("/".parse().unwrap())
|
||||
}
|
||||
|
||||
async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
|
||||
async fn google_auth() -> impl IntoResponse {
|
||||
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
// Generate the authorization URL to which we'll redirect the user.
|
||||
let (auth_url, csrf_state) = google_oauth_client
|
||||
let (auth_url, csrf_state) = google_oauth_client()
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new(
|
||||
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
||||
|
|
@ -305,7 +312,8 @@ async fn google_auth(Extension(google_oauth_client): Extension<BasicClient>) ->
|
|||
Redirect::to(auth_url.to_string().parse().unwrap())
|
||||
}
|
||||
|
||||
async fn discord_auth(Extension(discord_oauth_client): Extension<BasicClient>) -> impl IntoResponse {
|
||||
async fn discord_auth() -> impl IntoResponse {
|
||||
let discord_oauth_client = discord_oauth_client();
|
||||
let (auth_url, _csrf_token) = discord_oauth_client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("identify".to_string()))
|
||||
|
|
@ -352,7 +360,7 @@ async fn google_authorized(
|
|||
*/
|
||||
|
||||
// Create a new session filled with user data
|
||||
let mut session = Session::new();
|
||||
let session = Session::new();
|
||||
//session.insert("user", &user_data).unwrap();
|
||||
|
||||
// Store session and get corresponding cookie
|
||||
|
|
@ -381,8 +389,13 @@ async fn google_authorized(
|
|||
async fn discord_authorized(
|
||||
Query(query): Query<AuthRequest>,
|
||||
Extension(store): Extension<MemoryStore>,
|
||||
Extension(discord_oauth_client): Extension<BasicClient>,
|
||||
Extension(oauth_clients): Extension<HashMap::<&str, BasicClient>>,
|
||||
) -> impl IntoResponse {
|
||||
// Check for Discord client
|
||||
if oauth_clients.contains_key("Discord") {
|
||||
// Get Discord client
|
||||
let discord_oauth_client = oauth_clients.get(&"Discord").unwrap();
|
||||
|
||||
// Get an auth token
|
||||
let token = discord_oauth_client
|
||||
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||
|
|
@ -418,6 +431,11 @@ async fn discord_authorized(
|
|||
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
||||
|
||||
(headers, Redirect::to("/dashboard".parse().unwrap()))
|
||||
} else {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
(headers, Redirect::to("/".parse().unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthRedirect;
|
||||
|
|
|
|||
Loading…
Reference in New Issue