use askama::Template; use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, extract::{ rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts, TypedHeader, }, headers::Cookie, http::{ self, header::SET_COOKIE, header::{HeaderMap, HeaderValue}, StatusCode, }, response::{Html, IntoResponse, Redirect, Response}, routing::{get, get_service}, Router, }; use http::header; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, ClientSecret, TokenResponse, CsrfToken, }; use serde::{Deserialize, Serialize}; use std::{env, net::SocketAddr, collections::HashMap}; use tower_http::{services::ServeDir}; use uuid::Uuid; const COOKIE_NAME: &str = "SESSION"; // The user data we'll get back from Discord. // https://discord.com/developers/docs/resources/user#user-object-user-structure #[derive(Debug, Serialize, Deserialize)] struct User { id: String, avatar: Option, username: String, discriminator: String, } #[tokio::main] async fn main() { // Set the RUST_LOG, if it hasn't been explicitly defined if std::env::var_os("RUST_LOG").is_none() { std::env::set_var("RUST_LOG", "example_sessions=debug") } // initialize tracing tracing_subscriber::fmt::init(); // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); // Create HashMap to store oauth configurations let mut oauth_clients = HashMap::<&str, BasicClient>::new(); // Get the client structures let discord_oauth_client = discord_oauth_client(); // Get oauth clients for the hashmap //oauth_clients.insert("Google".to_string(), google_oauth_client); oauth_clients.insert("Discord", discord_oauth_client); // build our application with a route let app = Router::new() // `GET /` goes to `root` .nest( "/assets", get_service(ServeDir::new("templates/assets")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route("/", get(index)) .route("/google_auth", get(google_auth)) .route("/discord_auth", get(discord_auth)) .route("/login", get(login)) .route("/logout", get(logout)) .route("/dashboard", get(dashboard)) .route("/auth/callback", get(google_authorized)) .route("/auth/discord", get(discord_authorized)) .layer(Extension(store)) .layer(Extension(oauth_clients)); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` let addr = SocketAddr::from(([0, 0, 0, 0], 40192)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } // Session is optional async fn index(user: Option) -> impl IntoResponse { let (userid, name) = match user { Some(u) => (true, u.username), None => (false, "You're not logged in.".to_string()), }; let template = IndexTemplate { userid, name }; HtmlTemplate(template) } async fn login() -> impl IntoResponse { let name = "".to_string(); let userid = false; let template = LoginTemplate { userid, name }; HtmlTemplate(template) } // Valid user session required. If there is none, redirect to the auth page async fn dashboard(user: User) -> impl IntoResponse { let name = user.username; let userid = true; let template = DashboardTemplate { userid, name }; HtmlTemplate(template) } fn google_oauth_client() -> BasicClient { let redirect_url = env::var("REDIRECT_URL") .unwrap_or_else(|_| "http://localhost:40192/auth/callback".to_string()); let google_client_id = env::var("GOOGLE_CLIENT_ID").expect("Missing GOOGLE_CLIENT_ID!"); let google_client_secret = env::var("GOOGLE_CLIENT_SECRET").expect("Missing GOOGLE_CLIENT_SECRET!"); let google_auth_url = env::var("GOOGLE_AUTH_URL").unwrap_or_else(|_| { "https://accounts.google.com/o/oauth2/v2/auth".to_string() }); let google_token_url = env::var("GOOGLE_TOKEN_URL") .unwrap_or_else(|_| "https://www.googleapis.com/oauth2/v3/token".to_string()); BasicClient::new( ClientId::new(google_client_id), Some(ClientSecret::new(google_client_secret)), AuthUrl::new(google_auth_url).unwrap(), Some(TokenUrl::new(google_token_url).unwrap()), ) .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) } fn discord_oauth_client() -> BasicClient { let redirect_url = env::var("REDIRECT_URL") .unwrap_or_else(|_| "http://localhost:40192/auth/discord".to_string()); let discord_client_id = env::var("DISCORD_CLIENT_ID").expect("Missing DISCORD_CLIENT_ID!"); let discord_client_secret = env::var("DISCORD_CLIENT_SECRET").expect("Missing DISCORD_CLIENT_SECRET!"); let discord_auth_url = env::var("DISCORD_AUTH_URL").unwrap_or_else(|_| { "https://discord.com/api/oauth2/authorize?response_type=code".to_string() }); let discord_token_url = env::var("DISCORD_TOKEN_URL") .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); BasicClient::new( ClientId::new(discord_client_id), Some(ClientSecret::new(discord_client_secret)), AuthUrl::new(discord_auth_url).unwrap(), Some(TokenUrl::new(discord_token_url).unwrap()), ) .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) } #[derive(Template)] #[template(path = "login.html")] struct LoginTemplate { userid: bool, name: String, } #[derive(Template)] #[template(path = "dashboard.html")] struct DashboardTemplate { userid: bool, name: String, } struct FreshUserId { pub user_id: UserId, pub cookie: HeaderValue, } enum UserIdFromSession { FoundUserId(UserId), CreatedFreshUserId(FreshUserId), } #[async_trait] impl FromRequest for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); async fn from_request(req: &mut RequestParts) -> Result { let Extension(store) = Extension::::from_request(req) .await .expect("`MemoryStore` extension missing"); let cookie = Option::>::from_request(req) .await .unwrap(); let session_cookie = cookie.as_ref().and_then(|cookie| cookie.get(COOKIE_NAME)); // return the new created session cookie for client if session_cookie.is_none() { let user_id = UserId::new(); let mut session = Session::new(); session.insert("user_id", user_id).unwrap(); let cookie = store.store_session(session).await.unwrap().unwrap(); return Ok(Self::CreatedFreshUserId(FreshUserId { user_id, cookie: HeaderValue::from_str(format!("{}={}", COOKIE_NAME, cookie).as_str()) .unwrap(), })); } tracing::debug!( "UserIdFromSession: got session cookie from user agent, {}={}", COOKIE_NAME, session_cookie.unwrap() ); // continue to decode the session cookie let user_id = if let Some(session) = store .load_session(session_cookie.unwrap().to_owned()) .await .unwrap() { if let Some(user_id) = session.get::("user_id") { tracing::debug!( "UserIdFromSession: session decoded success, user_id={:?}", user_id ); user_id } else { return Err(( StatusCode::INTERNAL_SERVER_ERROR, "No `user_id` found in session", )); } } else { tracing::debug!( "UserIdFromSession: err session not exists in store, {}={}", COOKIE_NAME, session_cookie.unwrap() ); return Err((StatusCode::BAD_REQUEST, "No session found for cookie")); }; Ok(Self::FoundUserId(user_id)) } } #[derive(Serialize, Deserialize, Debug, Clone, Copy)] struct UserId(Uuid); impl UserId { fn new() -> Self { Self(Uuid::new_v4()) } } #[derive(Template)] #[template(path = "index.html")] struct IndexTemplate { userid: bool, name: String, } struct HtmlTemplate(T); impl IntoResponse for HtmlTemplate where T: Template, { fn into_response(self) -> Response { match self.0.render() { Ok(html) => Html(html).into_response(), Err(err) => ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to render template. Error: {}", err), ) .into_response(), } } } async fn logout( Extension(store): Extension, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); let session = match store.load_session(cookie.to_string()).await.unwrap() { Some(s) => s, // No session active, just redirect None => return Redirect::to("/".parse().unwrap()), }; store.destroy_session(session).await.unwrap(); Redirect::to("/".parse().unwrap()) } async fn google_auth() -> impl IntoResponse { let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); // Generate the authorization URL to which we'll redirect the user. let (auth_url, csrf_state) = google_oauth_client() .authorize_url(CsrfToken::new_random) .add_scope(Scope::new( "https://www.googleapis.com/auth/userinfo.profile".to_string(), )) .add_scope(Scope::new( "https://www.googleapis.com/auth/userinfo.email".to_string(), )) .set_pkce_challenge(pkce_code_challenge) .url(); // Redirect to Google's oauth service Redirect::to(auth_url.to_string().parse().unwrap()) } async fn discord_auth() -> impl IntoResponse { let discord_oauth_client = discord_oauth_client(); let (auth_url, _csrf_token) = discord_oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); // Redirect to Discord's oauth service Redirect::to(auth_url.to_string().parse().unwrap()) } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct AuthRequest { code: String, state: String, } async fn google_authorized( Query(query): Query, Extension(store): Extension, Extension(google_oauth_client): Extension, ) -> impl IntoResponse { print!("{}", query.code); /* // Get an auth token let token = google_oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await .unwrap(); // Fetch user data from discord let client = reqwest::Client::new(); let user_data: User = client // https://discord.com/developers/docs/resources/user#get-current-user .get("https://discordapp.com/api/users/@me") .bearer_auth(token.access_token().secret()) .send() .await .unwrap() .json::() .await .unwrap(); */ // Create a new session filled with user data let session = Session::new(); //session.insert("user", &user_data).unwrap(); // Store session and get corresponding cookie let cookie = store.store_session(session).await.unwrap().unwrap(); // Build the cookie let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); // Set cookie let mut headers = HeaderMap::new(); headers.insert(SET_COOKIE, cookie.parse().unwrap()); //(headers, Redirect::to("/dashboard".parse().unwrap())) let mut page = String::new(); page.push_str(&"Display the data returned by Google\n".to_string()); page.push_str(&"\nState: ".to_string()); page.push_str(&query.state); page.push_str(&"\nCode: ".to_string()); page.push_str(&query.code); page.push_str(&"\nScope: ".to_string()); page } async fn discord_authorized( Query(query): Query, Extension(store): Extension, Extension(oauth_clients): Extension>, ) -> impl IntoResponse { // Check for Discord client if oauth_clients.contains_key("Discord") { // Get Discord client let discord_oauth_client = oauth_clients.get(&"Discord").unwrap(); // Get an auth token let token = discord_oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await .unwrap(); // Fetch user data from discord let client = reqwest::Client::new(); let user_data: User = client // https://discord.com/developers/docs/resources/user#get-current-user .get("https://discordapp.com/api/users/@me") .bearer_auth(token.access_token().secret()) .send() .await .unwrap() .json::() .await .unwrap(); // Create a new session filled with user data let mut session = Session::new(); session.insert("user", &user_data).unwrap(); // Store session and get corresponding cookie let cookie = store.store_session(session).await.unwrap().unwrap(); // Build the cookie let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); // Set cookie let mut headers = HeaderMap::new(); headers.insert(SET_COOKIE, cookie.parse().unwrap()); (headers, Redirect::to("/dashboard".parse().unwrap())) } else { let mut headers = HeaderMap::new(); (headers, Redirect::to("/".parse().unwrap())) } } struct AuthRedirect; impl IntoResponse for AuthRedirect { fn into_response(self) -> Response { Redirect::temporary("/login".parse().unwrap()).into_response() } } #[async_trait] impl FromRequest for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; async fn from_request(req: &mut RequestParts) -> Result { let Extension(store) = Extension::::from_request(req) .await .expect("`MemoryStore` extension is missing"); let cookies = TypedHeader::::from_request(req) .await .map_err(|e| match *e.name() { header::COOKIE => match e.reason() { TypedHeaderRejectionReason::Missing => AuthRedirect, _ => panic!("unexpected error getting Cookie header(s): {}", e), }, _ => panic!("unexpected error getting cookies: {}", e), })?; let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?; let session = store .load_session(session_cookie.to_string()) .await .unwrap() .ok_or(AuthRedirect)?; let user = session.get::("user").ok_or(AuthRedirect)?; Ok(user) } }