496 lines
15 KiB
Rust
496 lines
15 KiB
Rust
use askama::Template;
|
|
use async_session::{MemoryStore, Session, SessionStore as _};
|
|
use axum::{
|
|
async_trait,
|
|
extract::{
|
|
rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts,
|
|
TypedHeader,
|
|
},
|
|
headers::Cookie,
|
|
http::{
|
|
self,
|
|
header::SET_COOKIE,
|
|
header::{HeaderMap, HeaderValue},
|
|
StatusCode,
|
|
},
|
|
response::{Html, IntoResponse, Redirect, Response},
|
|
routing::{get, get_service},
|
|
Router,
|
|
};
|
|
use http::header;
|
|
use oauth2::{
|
|
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
|
PkceCodeChallenge, RedirectUrl, Scope, TokenUrl,
|
|
ClientSecret, TokenResponse, CsrfToken,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{env, net::SocketAddr, collections::HashMap};
|
|
use tower_http::{services::ServeDir};
|
|
use uuid::Uuid;
|
|
|
|
/// Name of the cookie that carries the session identifier handed out by the session store.
const COOKIE_NAME: &str = "SESSION";
|
|
|
|
// The user data we'll get back from Discord.
|
|
// https://discord.com/developers/docs/resources/user#user-object-user-structure
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct User {
|
|
id: String,
|
|
avatar: Option<String>,
|
|
username: String,
|
|
discriminator: String,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
// Set the RUST_LOG, if it hasn't been explicitly defined
|
|
if std::env::var_os("RUST_LOG").is_none() {
|
|
std::env::set_var("RUST_LOG", "example_sessions=debug")
|
|
}
|
|
// initialize tracing
|
|
tracing_subscriber::fmt::init();
|
|
|
|
// `MemoryStore` just used as an example. Don't use this in production.
|
|
let store = MemoryStore::new();
|
|
|
|
// Create HashMap to store oauth configurations
|
|
let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
|
|
|
// Get the client structures
|
|
let discord_oauth_client = discord_oauth_client();
|
|
|
|
// Get oauth clients for the hashmap
|
|
//oauth_clients.insert("Google".to_string(), google_oauth_client);
|
|
oauth_clients.insert("Discord", discord_oauth_client);
|
|
|
|
// build our application with a route
|
|
let app = Router::new()
|
|
// `GET /` goes to `root`
|
|
.nest(
|
|
"/assets",
|
|
get_service(ServeDir::new("templates/assets")).handle_error(
|
|
|error: std::io::Error| async move {
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
format!("Unhandled internal error: {}", error),
|
|
)
|
|
},
|
|
),
|
|
)
|
|
.route("/", get(index))
|
|
.route("/google_auth", get(google_auth))
|
|
.route("/discord_auth", get(discord_auth))
|
|
.route("/login", get(login))
|
|
.route("/logout", get(logout))
|
|
.route("/dashboard", get(dashboard))
|
|
.route("/auth/callback", get(google_authorized))
|
|
.route("/auth/discord", get(discord_authorized))
|
|
.layer(Extension(store))
|
|
.layer(Extension(oauth_clients));
|
|
|
|
// run our app with hyper
|
|
// `axum::Server` is a re-export of `hyper::Server`
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], 40192));
|
|
tracing::debug!("listening on {}", addr);
|
|
axum::Server::bind(&addr)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
// Session is optional
|
|
async fn index(user: Option<User>) -> impl IntoResponse {
|
|
let (userid, name) = match user {
|
|
Some(u) => (true, u.username),
|
|
None => (false, "You're not logged in.".to_string()),
|
|
};
|
|
let template = IndexTemplate { userid, name };
|
|
HtmlTemplate(template)
|
|
}
|
|
|
|
async fn login() -> impl IntoResponse {
|
|
let name = "".to_string();
|
|
let userid = false;
|
|
let template = LoginTemplate { userid, name };
|
|
HtmlTemplate(template)
|
|
}
|
|
|
|
// Valid user session required. If there is none, redirect to the auth page
|
|
async fn dashboard(user: User) -> impl IntoResponse {
|
|
let name = user.username;
|
|
let userid = true;
|
|
let template = DashboardTemplate { userid, name };
|
|
HtmlTemplate(template)
|
|
}
|
|
|
|
fn google_oauth_client() -> BasicClient {
|
|
let redirect_url = env::var("REDIRECT_URL")
|
|
.unwrap_or_else(|_| "http://localhost:40192/auth/callback".to_string());
|
|
|
|
let google_client_id = env::var("GOOGLE_CLIENT_ID").expect("Missing GOOGLE_CLIENT_ID!");
|
|
let google_client_secret = env::var("GOOGLE_CLIENT_SECRET").expect("Missing GOOGLE_CLIENT_SECRET!");
|
|
let google_auth_url = env::var("GOOGLE_AUTH_URL").unwrap_or_else(|_| {
|
|
"https://accounts.google.com/o/oauth2/v2/auth".to_string()
|
|
});
|
|
let google_token_url = env::var("GOOGLE_TOKEN_URL")
|
|
.unwrap_or_else(|_| "https://www.googleapis.com/oauth2/v3/token".to_string());
|
|
|
|
BasicClient::new(
|
|
ClientId::new(google_client_id),
|
|
Some(ClientSecret::new(google_client_secret)),
|
|
AuthUrl::new(google_auth_url).unwrap(),
|
|
Some(TokenUrl::new(google_token_url).unwrap()),
|
|
)
|
|
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
|
|
}
|
|
|
|
fn discord_oauth_client() -> BasicClient {
|
|
let redirect_url = env::var("REDIRECT_URL")
|
|
.unwrap_or_else(|_| "http://localhost:40192/auth/discord".to_string());
|
|
|
|
let discord_client_id = env::var("DISCORD_CLIENT_ID").expect("Missing DISCORD_CLIENT_ID!");
|
|
let discord_client_secret = env::var("DISCORD_CLIENT_SECRET").expect("Missing DISCORD_CLIENT_SECRET!");
|
|
let discord_auth_url = env::var("DISCORD_AUTH_URL").unwrap_or_else(|_| {
|
|
"https://discord.com/api/oauth2/authorize?response_type=code".to_string()
|
|
});
|
|
let discord_token_url = env::var("DISCORD_TOKEN_URL")
|
|
.unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());
|
|
|
|
BasicClient::new(
|
|
ClientId::new(discord_client_id),
|
|
Some(ClientSecret::new(discord_client_secret)),
|
|
AuthUrl::new(discord_auth_url).unwrap(),
|
|
Some(TokenUrl::new(discord_token_url).unwrap()),
|
|
)
|
|
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
|
|
}
|
|
|
|
#[derive(Template)]
|
|
#[template(path = "login.html")]
|
|
struct LoginTemplate {
|
|
userid: bool,
|
|
name: String,
|
|
}
|
|
|
|
#[derive(Template)]
|
|
#[template(path = "dashboard.html")]
|
|
struct DashboardTemplate {
|
|
userid: bool,
|
|
name: String,
|
|
}
|
|
|
|
struct FreshUserId {
|
|
pub user_id: UserId,
|
|
pub cookie: HeaderValue,
|
|
}
|
|
|
|
enum UserIdFromSession {
|
|
FoundUserId(UserId),
|
|
CreatedFreshUserId(FreshUserId),
|
|
}
|
|
|
|
#[async_trait]
|
|
impl<B> FromRequest<B> for UserIdFromSession
|
|
where
|
|
B: Send,
|
|
{
|
|
type Rejection = (StatusCode, &'static str);
|
|
|
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
|
let Extension(store) = Extension::<MemoryStore>::from_request(req)
|
|
.await
|
|
.expect("`MemoryStore` extension missing");
|
|
|
|
let cookie = Option::<TypedHeader<Cookie>>::from_request(req)
|
|
.await
|
|
.unwrap();
|
|
|
|
let session_cookie = cookie.as_ref().and_then(|cookie| cookie.get(COOKIE_NAME));
|
|
|
|
// return the new created session cookie for client
|
|
if session_cookie.is_none() {
|
|
let user_id = UserId::new();
|
|
let mut session = Session::new();
|
|
session.insert("user_id", user_id).unwrap();
|
|
let cookie = store.store_session(session).await.unwrap().unwrap();
|
|
return Ok(Self::CreatedFreshUserId(FreshUserId {
|
|
user_id,
|
|
cookie: HeaderValue::from_str(format!("{}={}", COOKIE_NAME, cookie).as_str())
|
|
.unwrap(),
|
|
}));
|
|
}
|
|
|
|
tracing::debug!(
|
|
"UserIdFromSession: got session cookie from user agent, {}={}",
|
|
COOKIE_NAME,
|
|
session_cookie.unwrap()
|
|
);
|
|
// continue to decode the session cookie
|
|
let user_id = if let Some(session) = store
|
|
.load_session(session_cookie.unwrap().to_owned())
|
|
.await
|
|
.unwrap()
|
|
{
|
|
if let Some(user_id) = session.get::<UserId>("user_id") {
|
|
tracing::debug!(
|
|
"UserIdFromSession: session decoded success, user_id={:?}",
|
|
user_id
|
|
);
|
|
user_id
|
|
} else {
|
|
return Err((
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
"No `user_id` found in session",
|
|
));
|
|
}
|
|
} else {
|
|
tracing::debug!(
|
|
"UserIdFromSession: err session not exists in store, {}={}",
|
|
COOKIE_NAME,
|
|
session_cookie.unwrap()
|
|
);
|
|
return Err((StatusCode::BAD_REQUEST, "No session found for cookie"));
|
|
};
|
|
|
|
Ok(Self::FoundUserId(user_id))
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
|
|
struct UserId(Uuid);
|
|
|
|
impl UserId {
|
|
fn new() -> Self {
|
|
Self(Uuid::new_v4())
|
|
}
|
|
}
|
|
|
|
#[derive(Template)]
|
|
#[template(path = "index.html")]
|
|
struct IndexTemplate {
|
|
userid: bool,
|
|
name: String,
|
|
}
|
|
|
|
struct HtmlTemplate<T>(T);
|
|
|
|
impl<T> IntoResponse for HtmlTemplate<T>
|
|
where
|
|
T: Template,
|
|
{
|
|
fn into_response(self) -> Response {
|
|
match self.0.render() {
|
|
Ok(html) => Html(html).into_response(),
|
|
Err(err) => (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
format!("Failed to render template. Error: {}", err),
|
|
)
|
|
.into_response(),
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn logout(
|
|
Extension(store): Extension<MemoryStore>,
|
|
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
|
) -> impl IntoResponse {
|
|
let cookie = cookies.get(COOKIE_NAME).unwrap();
|
|
let session = match store.load_session(cookie.to_string()).await.unwrap() {
|
|
Some(s) => s,
|
|
// No session active, just redirect
|
|
None => return Redirect::to("/".parse().unwrap()),
|
|
};
|
|
|
|
store.destroy_session(session).await.unwrap();
|
|
|
|
Redirect::to("/".parse().unwrap())
|
|
}
|
|
|
|
async fn google_auth() -> impl IntoResponse {
|
|
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
|
|
|
// Generate the authorization URL to which we'll redirect the user.
|
|
let (auth_url, csrf_state) = google_oauth_client()
|
|
.authorize_url(CsrfToken::new_random)
|
|
.add_scope(Scope::new(
|
|
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
|
))
|
|
.add_scope(Scope::new(
|
|
"https://www.googleapis.com/auth/userinfo.email".to_string(),
|
|
))
|
|
.set_pkce_challenge(pkce_code_challenge)
|
|
.url();
|
|
|
|
// Redirect to Google's oauth service
|
|
Redirect::to(auth_url.to_string().parse().unwrap())
|
|
}
|
|
|
|
async fn discord_auth() -> impl IntoResponse {
|
|
let discord_oauth_client = discord_oauth_client();
|
|
let (auth_url, _csrf_token) = discord_oauth_client
|
|
.authorize_url(CsrfToken::new_random)
|
|
.add_scope(Scope::new("identify".to_string()))
|
|
.url();
|
|
|
|
// Redirect to Discord's oauth service
|
|
Redirect::to(auth_url.to_string().parse().unwrap())
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[allow(dead_code)]
|
|
struct AuthRequest {
|
|
code: String,
|
|
state: String,
|
|
}
|
|
|
|
async fn google_authorized(
|
|
Query(query): Query<AuthRequest>,
|
|
Extension(store): Extension<MemoryStore>,
|
|
Extension(google_oauth_client): Extension<BasicClient>,
|
|
) -> impl IntoResponse {
|
|
|
|
print!("{}", query.code);
|
|
/*
|
|
// Get an auth token
|
|
let token = google_oauth_client
|
|
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
|
.request_async(async_http_client)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Fetch user data from discord
|
|
let client = reqwest::Client::new();
|
|
let user_data: User = client
|
|
// https://discord.com/developers/docs/resources/user#get-current-user
|
|
.get("https://discordapp.com/api/users/@me")
|
|
.bearer_auth(token.access_token().secret())
|
|
.send()
|
|
.await
|
|
.unwrap()
|
|
.json::<User>()
|
|
.await
|
|
.unwrap();
|
|
*/
|
|
|
|
// Create a new session filled with user data
|
|
let session = Session::new();
|
|
//session.insert("user", &user_data).unwrap();
|
|
|
|
// Store session and get corresponding cookie
|
|
let cookie = store.store_session(session).await.unwrap().unwrap();
|
|
|
|
// Build the cookie
|
|
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
|
|
|
|
// Set cookie
|
|
let mut headers = HeaderMap::new();
|
|
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
|
|
|
//(headers, Redirect::to("/dashboard".parse().unwrap()))
|
|
|
|
let mut page = String::new();
|
|
page.push_str(&"Display the data returned by Google\n".to_string());
|
|
page.push_str(&"\nState: ".to_string());
|
|
page.push_str(&query.state);
|
|
page.push_str(&"\nCode: ".to_string());
|
|
page.push_str(&query.code);
|
|
page.push_str(&"\nScope: ".to_string());
|
|
|
|
page
|
|
}
|
|
|
|
async fn discord_authorized(
|
|
Query(query): Query<AuthRequest>,
|
|
Extension(store): Extension<MemoryStore>,
|
|
Extension(oauth_clients): Extension<HashMap::<&str, BasicClient>>,
|
|
) -> impl IntoResponse {
|
|
// Check for Discord client
|
|
if oauth_clients.contains_key("Discord") {
|
|
// Get Discord client
|
|
let discord_oauth_client = oauth_clients.get(&"Discord").unwrap();
|
|
|
|
// Get an auth token
|
|
let token = discord_oauth_client
|
|
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
|
.request_async(async_http_client)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Fetch user data from discord
|
|
let client = reqwest::Client::new();
|
|
let user_data: User = client
|
|
// https://discord.com/developers/docs/resources/user#get-current-user
|
|
.get("https://discordapp.com/api/users/@me")
|
|
.bearer_auth(token.access_token().secret())
|
|
.send()
|
|
.await
|
|
.unwrap()
|
|
.json::<User>()
|
|
.await
|
|
.unwrap();
|
|
|
|
// Create a new session filled with user data
|
|
let mut session = Session::new();
|
|
session.insert("user", &user_data).unwrap();
|
|
|
|
// Store session and get corresponding cookie
|
|
let cookie = store.store_session(session).await.unwrap().unwrap();
|
|
|
|
// Build the cookie
|
|
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
|
|
|
|
// Set cookie
|
|
let mut headers = HeaderMap::new();
|
|
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
|
|
|
(headers, Redirect::to("/dashboard".parse().unwrap()))
|
|
} else {
|
|
let mut headers = HeaderMap::new();
|
|
|
|
(headers, Redirect::to("/".parse().unwrap()))
|
|
}
|
|
}
|
|
|
|
struct AuthRedirect;
|
|
|
|
impl IntoResponse for AuthRedirect {
|
|
fn into_response(self) -> Response {
|
|
Redirect::temporary("/login".parse().unwrap()).into_response()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl<B> FromRequest<B> for User
|
|
where
|
|
B: Send,
|
|
{
|
|
// If anything goes wrong or no session is found, redirect to the auth page
|
|
type Rejection = AuthRedirect;
|
|
|
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
|
let Extension(store) = Extension::<MemoryStore>::from_request(req)
|
|
.await
|
|
.expect("`MemoryStore` extension is missing");
|
|
|
|
let cookies = TypedHeader::<headers::Cookie>::from_request(req)
|
|
.await
|
|
.map_err(|e| match *e.name() {
|
|
header::COOKIE => match e.reason() {
|
|
TypedHeaderRejectionReason::Missing => AuthRedirect,
|
|
_ => panic!("unexpected error getting Cookie header(s): {}", e),
|
|
},
|
|
_ => panic!("unexpected error getting cookies: {}", e),
|
|
})?;
|
|
let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;
|
|
|
|
let session = store
|
|
.load_session(session_cookie.to_string())
|
|
.await
|
|
.unwrap()
|
|
.ok_or(AuthRedirect)?;
|
|
|
|
let user = session.get::<User>("user").ok_or(AuthRedirect)?;
|
|
|
|
Ok(user)
|
|
}
|
|
}
|