From e5642c7e9c0e8a0f60cd505db0b0a0467138203c Mon Sep 17 00:00:00 2001 From: Chris Jean-Marie Date: Mon, 7 Mar 2022 13:57:32 +0000 Subject: [PATCH] Add oauth --- Cargo.lock | 430 +++++++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 4 + src/main.rs | 215 ++++++++++++++++++++++++-- 3 files changed, 633 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f2556b7..869883f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899" + [[package]] name = "bytes" version = "1.1.0" @@ -324,6 +330,15 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "encoding_rs" +version = "0.8.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dc8abb250ffdda33912550faa54c88ec8b998dec0b2c55ab224921ce11df" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "event-listener" version = "2.5.2" @@ -361,6 +376,12 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +[[package]] +name = "futures-io" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" + [[package]] name = "futures-sink" version = "0.3.21" @@ -380,9 +401,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ "futures-core", + "futures-io", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -402,10 +426,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] +[[package]] +name = "h2" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9f1f717ddc7b2ba36df7e871fd88db79326551d3d6f1fc406fbfd28b582ff8e" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util 0.6.9", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" + [[package]] name = "headers" version = "0.3.7" @@ -506,6 +557,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "httparse", @@ -519,6 +571,46 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac" +dependencies = [ + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indexmap" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282a6247722caba404c065016bbfa522806e51714c34f5dfc3e4a3a46fcb4223" +dependencies = [ + "autocfg", + "hashbrown", +] + +[[package]] +name = "ipnet" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f2d64f2edebec4ce84ad108148e67e1064789bee435edc5b60ad398714a3a9" + [[package]] name = "itoa" version = "1.0.1" @@ -532,6 +624,10 @@ dependencies = [ "askama", "async-session", "axum", + "headers", + "http", + "oauth2", + "reqwest", "serde", "serde_json", "tokio", @@ -540,6 +636,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "js-sys" +version = "0.3.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -690,6 +795,26 @@ dependencies = [ "libc", ] +[[package]] +name = "oauth2" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e47cfc4c0a1a519d9a025ebfbac3a2439d1b5cdf397d72dcb79b11d9920dab" +dependencies = [ + "base64", + "chrono", + "getrandom", + "http", + "rand", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror", + "url", +] + [[package]] name = "once_cell" version = "1.10.0" @@ -850,6 +975,80 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "reqwest" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f242f1488a539a79bac6dbe7c8609ae43b7914b7736210f239a37cccb32525" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "lazy_static", + "log", + "mime", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-rustls", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", + "winreg", +] + +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + +[[package]] +name = "rustls" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fbfeb8d0ddb84706bc597a5574ab8912817c52a397f819e5b614e2265206921" +dependencies = [ + "log", + "ring", + "sct", + "webpki", +] + +[[package]] +name = "rustls-pemfile" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" +dependencies = [ + "base64", +] + [[package]] name = "ryu" version = "1.0.9" @@ -862,6 +1061,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "serde" version = "1.0.136" @@ -893,6 +1102,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7868ad3b8196a8a0aea99a8220b124278ee5320a55e4fde97794b6f85b1a377" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -947,6 +1165,12 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" + [[package]] name = "smallvec" version = "1.8.0" @@ -963,6 +1187,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "subtle" version = "2.4.1" @@ -986,6 +1216,26 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" +[[package]] +name = "thiserror" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.4" @@ -995,6 +1245,21 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tinyvec" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c1c1d5a42b6245520c249549ec267180beaffcc0615401ac8e31853d4b6d8d2" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" + [[package]] name = "tokio" version = "1.17.0" @@ -1026,6 +1291,31 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + +[[package]] +name = "tokio-util" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "log", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.0" @@ -1060,7 +1350,7 @@ dependencies = [ "pin-project", "pin-project-lite", "tokio", - "tokio-util", + "tokio-util 0.7.0", "tower-layer", "tower-service", "tracing", @@ -1181,12 +1471,46 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicode-bidi" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a01404663e3db436ed2746d9fefef640d868edae3cceb81c3b8d5732fda678f" + +[[package]] +name = "unicode-normalization" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-xid" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "url" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c" +dependencies = [ + "form_urlencoded", + "idna", + "matches", + "percent-encoding", + "serde", +] + [[package]] name = "uuid" version = "0.8.2" @@ -1225,6 +1549,101 @@ version = "0.10.2+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +[[package]] +name = "wasm-bindgen" +version = "0.2.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06" +dependencies = [ + "cfg-if 1.0.0", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca" +dependencies = [ + "bumpalo", + "lazy_static", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb6ec270a31b1d3c7e266b999739109abce8b6c87e4b31fcfcd788b65267395" +dependencies = [ + "cfg-if 1.0.0", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2" + +[[package]] +name = "web-sys" +version = "0.3.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "webpki-roots" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552ceb903e957524388c4d3475725ff2c8b7960922063af6ce53c9a43da07449" +dependencies = [ + "webpki", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1289,3 +1708,12 @@ name = "windows_x86_64_msvc" version = "0.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" + +[[package]] +name = "winreg" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0120db82e8a1e0b9fb3345a539c478767c0048d842860994d96113d5b667bd69" +dependencies = [ + "winapi", +] diff --git a/Cargo.toml b/Cargo.toml index a0543e1..22e34a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,7 @@ tracing-subscriber = { version="0.3", features = ["env-filter"] } uuid = { version = "0.8", features = ["v4", "serde"] } async-session = "3.0.0" askama = "0.11" +oauth2 = "4.1" +reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] } +headers = "0.3" +http = "0.2" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index cfd8a03..f838393 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,22 +2,32 @@ use askama::Template; use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{self, Extension, FromRequest, Path, RequestParts, TypedHeader}, + extract::{ + self, rejection::TypedHeaderRejectionReason, Extension, FromRequest, Path, Query, + RequestParts, TypedHeader, + }, headers::Cookie, http::{ self, + header::SET_COOKIE, header::{HeaderMap, HeaderValue}, StatusCode, }, - response::{Html, IntoResponse, Response}, + response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, Json, Router, }; +use http::header; +use oauth2::{ + basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, + ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, +}; use serde::{Deserialize, Serialize}; -use std::net::SocketAddr; +use std::{env, net::SocketAddr}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use uuid::Uuid; -const AXUM_SESSION_COOKIE_NAME: &str = "axum_session"; +const COOKIE_NAME: &str = "SESSION"; #[tokio::main] async fn main() { @@ -30,13 +40,15 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); + let oauth_client = oauth_client(); // build our application with a route let app = Router::new() // `GET /` goes to `root` .route("/greet/:name", get(greet)) .route("/", get(handler)) - .layer(Extension(store)); + .layer(Extension(store)) + .layer(Extension(oauth_client)); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` @@ -64,7 +76,7 @@ async fn handler(user_id: UserIdFromSession) -> impl IntoResponse { headers, format!( "user_id={:?} session_cookie_name={} create_new_session_cookie={}", - user_id, AXUM_SESSION_COOKIE_NAME, create_cookie + user_id, COOKIE_NAME, create_cookie ), ) } @@ -95,9 +107,7 @@ where .await .unwrap(); - let session_cookie = cookie - .as_ref() - .and_then(|cookie| cookie.get(AXUM_SESSION_COOKIE_NAME)); + let session_cookie = cookie.as_ref().and_then(|cookie| cookie.get(COOKIE_NAME)); // return the new created session cookie for client if session_cookie.is_none() { @@ -107,16 +117,14 @@ where let cookie = store.store_session(session).await.unwrap().unwrap(); return Ok(Self::CreatedFreshUserId(FreshUserId { user_id, - cookie: HeaderValue::from_str( - format!("{}={}", AXUM_SESSION_COOKIE_NAME, cookie).as_str(), - ) - .unwrap(), + cookie: HeaderValue::from_str(format!("{}={}", COOKIE_NAME, cookie).as_str()) + .unwrap(), })); } tracing::debug!( "UserIdFromSession: got session cookie from user agent, {}={}", - AXUM_SESSION_COOKIE_NAME, + COOKIE_NAME, session_cookie.unwrap() ); // continue to decode the session cookie @@ -140,7 +148,7 @@ where } else { tracing::debug!( "UserIdFromSession: err session not exists in store, {}={}", - AXUM_SESSION_COOKIE_NAME, + COOKIE_NAME, session_cookie.unwrap() ); return Err((StatusCode::BAD_REQUEST, "No session found for cookie")); @@ -187,3 +195,180 @@ where } } } + +fn oauth_client() -> BasicClient { + // Environment variables (* = required): + // *"CLIENT_ID" "REPLACE_ME"; + // *"CLIENT_SECRET" "REPLACE_ME"; + // "REDIRECT_URL" "http://127.0.0.1:3000/auth/authorized"; + // "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code"; + // "TOKEN_URL" "https://discord.com/api/oauth2/token"; + + let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!"); + let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!"); + let redirect_url = env::var("REDIRECT_URL") + .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string()); + + let auth_url = env::var("AUTH_URL").unwrap_or_else(|_| { + "https://discord.com/api/oauth2/authorize?response_type=code".to_string() + }); + + let token_url = env::var("TOKEN_URL") + .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); + + BasicClient::new( + ClientId::new(client_id), + Some(ClientSecret::new(client_secret)), + AuthUrl::new(auth_url).unwrap(), + Some(TokenUrl::new(token_url).unwrap()), + ) + .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) +} + +// The user data we'll get back from Discord. +// https://discord.com/developers/docs/resources/user#user-object-user-structure +#[derive(Debug, Serialize, Deserialize)] +struct User { + id: String, + avatar: Option, + username: String, + discriminator: String, +} + +// Session is optional +async fn index(user: Option) -> impl IntoResponse { + match user { + Some(u) => format!( + "Hey {}! You're logged in!\nYou may now access `/protected`.\nLog out with `/logout`.", + u.username + ), + None => "You're not logged in.\nVisit `/auth/discord` to do so.".to_string(), + } +} + +async fn discord_auth(Extension(client): Extension) -> impl IntoResponse { + let (auth_url, _csrf_token) = client + .authorize_url(CsrfToken::new_random) + .add_scope(Scope::new("identify".to_string())) + .url(); + + // Redirect to Discord's oauth service + Redirect::to(auth_url.to_string().parse().unwrap()) +} + +// Valid user session required. If there is none, redirect to the auth page +async fn protected(user: User) -> impl IntoResponse { + format!( + "Welcome to the protected area :)\nHere's your info:\n{:?}", + user + ) +} + +async fn logout( + Extension(store): Extension, + TypedHeader(cookies): TypedHeader, +) -> impl IntoResponse { + let cookie = cookies.get(COOKIE_NAME).unwrap(); + let session = match store.load_session(cookie.to_string()).await.unwrap() { + Some(s) => s, + // No session active, just redirect + None => return Redirect::to("/".parse().unwrap()), + }; + + store.destroy_session(session).await.unwrap(); + + Redirect::to("/".parse().unwrap()) +} + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +struct AuthRequest { + code: String, + state: String, +} + +async fn login_authorized( + Query(query): Query, + Extension(store): Extension, + Extension(oauth_client): Extension, +) -> impl IntoResponse { + // Get an auth token + let token = oauth_client + .exchange_code(AuthorizationCode::new(query.code.clone())) + .request_async(async_http_client) + .await + .unwrap(); + + // Fetch user data from discord + let client = reqwest::Client::new(); + let user_data: User = client + // https://discord.com/developers/docs/resources/user#get-current-user + .get("https://discordapp.com/api/users/@me") + .bearer_auth(token.access_token().secret()) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + // Create a new session filled with user data + let mut session = Session::new(); + session.insert("user", &user_data).unwrap(); + + // Store session and get corresponding cookie + let cookie = store.store_session(session).await.unwrap().unwrap(); + + // Build the cookie + let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); + + // Set cookie + let mut headers = HeaderMap::new(); + headers.insert(SET_COOKIE, cookie.parse().unwrap()); + + (headers, Redirect::to("/".parse().unwrap())) +} + +struct AuthRedirect; + +impl IntoResponse for AuthRedirect { + fn into_response(self) -> Response { + Redirect::temporary("/auth/discord".parse().unwrap()).into_response() + } +} + +#[async_trait] +impl FromRequest for User +where + B: Send, +{ + // If anything goes wrong or no session is found, redirect to the auth page + type Rejection = AuthRedirect; + + async fn from_request(req: &mut RequestParts) -> Result { + let Extension(store) = Extension::::from_request(req) + .await + .expect("`MemoryStore` extension is missing"); + + let cookies = TypedHeader::::from_request(req) + .await + .map_err(|e| match *e.name() { + header::COOKIE => match e.reason() { + TypedHeaderRejectionReason::Missing => AuthRedirect, + _ => panic!("unexpected error getting Cookie header(s): {}", e), + }, + _ => panic!("unexpected error getting cookies: {}", e), + })?; + let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?; + + let session = store + .load_session(session_cookie.to_string()) + .await + .unwrap() + .ok_or(AuthRedirect)?; + + let user = session.get::("user").ok_or(AuthRedirect)?; + + Ok(user) + } +}