Merge with rust-axum-with-google-auth project
This commit is contained in:
parent
3691056d9f
commit
2fb1c74799
File diff suppressed because it is too large
Load Diff
|
|
@ -7,21 +7,24 @@ edition = "2021"
|
||||||
|
|
||||||
# Update all dependencies with `cargo upgrade -i allow && cargo update`
|
# Update all dependencies with `cargo upgrade -i allow && cargo update`
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = { version = "0.7.5" }
|
axum = { version = "0.7.6" }
|
||||||
axum_session = { version = "0.14.2" }
|
axum_session = { version = "0.14.2" }
|
||||||
axum-server = { version = "0.7.1" }
|
axum-server = { version = "0.7.1" }
|
||||||
|
axum-extra = { version = "0.9.4", features = ["cookie-private", "typed-header"] }
|
||||||
|
headers = "0.4"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
#serde_json = "1.0.68"
|
serde_json = "1.0"
|
||||||
tokio = { version = "1.40", features = ["full"] }
|
tokio = { version = "1.40", features = ["full"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
||||||
#uuid = { version = "1.1.2", features = ["v4", "serde"] }
|
|
||||||
#async-session = "3.0.0"
|
|
||||||
askama = "0.12"
|
askama = "0.12"
|
||||||
|
minijinja = { version = "2", features = ["loader"] }
|
||||||
oauth2 = "4.4"
|
oauth2 = "4.4"
|
||||||
#reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
|
|
||||||
#headers = "0.3"
|
|
||||||
http = "1.1"
|
http = "1.1"
|
||||||
tower-http = { version = "0.5.2", features = ["full"] }
|
tower-http = { version = "0.6.0", features = ["full"] }
|
||||||
#sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "macros", "migrate", "chrono", "json"]}
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
#anyhow = "1.0"
|
sqlx = { version = "0.8", features = ["sqlite", "runtime-tokio"] }
|
||||||
|
uuid = { version = "1.10", features = ["v4"] }
|
||||||
|
dotenvy = "0.15"
|
||||||
|
constant_time_eq = "0.3"
|
||||||
|
reqwest = "0.12"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{Html, IntoResponse},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct AppError {
|
||||||
|
code: StatusCode,
|
||||||
|
message: String,
|
||||||
|
user_message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppError {
|
||||||
|
pub fn new(message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
message: message.into(),
|
||||||
|
user_message: "".to_owned(),
|
||||||
|
code: StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn with_user_message(self, user_message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
user_message: user_message.into(),
|
||||||
|
..self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// pub fn with_code(self, code: StatusCode) -> Self {
|
||||||
|
// Self {
|
||||||
|
// code,
|
||||||
|
// ..self
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for AppError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
println!("AppError: {}", self.message);
|
||||||
|
(
|
||||||
|
self.code,
|
||||||
|
Html(format!(
|
||||||
|
r#"
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<title>Oops!</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Oops!</h1>
|
||||||
|
<p>Sorry, but something went wrong.</p>
|
||||||
|
<p>{}</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"#,
|
||||||
|
self.user_message
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<minijinja::Error> for AppError {
|
||||||
|
fn from(err: minijinja::Error) -> Self {
|
||||||
|
AppError::new(format!("Template error: {:#}", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<dotenvy::Error> for AppError {
|
||||||
|
fn from(err: dotenvy::Error) -> Self {
|
||||||
|
AppError::new(format!("Dotenv error: {:#}", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<sqlx::Error> for AppError {
|
||||||
|
fn from(err: sqlx::Error) -> Self {
|
||||||
|
AppError::new(format!("Database query error: {:#}", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for AppError {
|
||||||
|
fn from(err: String) -> Self {
|
||||||
|
AppError::new(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for AppError {
|
||||||
|
fn from(err: &str) -> Self {
|
||||||
|
AppError::new(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,8 +4,7 @@ use axum::{
|
||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use oauth2::{
|
use oauth2::{
|
||||||
basic::BasicClient, AuthUrl, ClientId,
|
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, TokenUrl
|
||||||
ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenUrl,
|
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
@ -46,13 +45,13 @@ pub async fn google_authorized(_session: Session<SessionAnyPool>,
|
||||||
println!("{:?}", query);
|
println!("{:?}", query);
|
||||||
|
|
||||||
// Get an auth token
|
// Get an auth token
|
||||||
//let token = match google_oauth_client
|
let token = match google_oauth_client()
|
||||||
//.exchange_code(AuthorizationCode::new(query.code.clone()))
|
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||||
//.request_async(async_http_client)
|
.request_async(async_http_client)
|
||||||
//.await {
|
.await {
|
||||||
// Ok(token) => token,
|
Ok(token) => token,
|
||||||
// Err(_) => panic!("Didn't get a token"),
|
Err(_) => panic!("Didn't get a token"),
|
||||||
// };
|
};
|
||||||
/*
|
/*
|
||||||
// Fetch user data from google
|
// Fetch user data from google
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
@ -77,7 +76,8 @@ pub async fn google_authorized(_session: Session<SessionAnyPool>,
|
||||||
page.push_str(&query.state);
|
page.push_str(&query.state);
|
||||||
page.push_str(&"\nCode: ".to_string());
|
page.push_str(&"\nCode: ".to_string());
|
||||||
page.push_str(&query.code);
|
page.push_str(&query.code);
|
||||||
page.push_str(&"\nScope: ".to_string());
|
page.push_str(&"\nAccess Token: ".to_string());
|
||||||
|
page.push_str(&token.access_token().secret());
|
||||||
|
|
||||||
page
|
page
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,350 +0,0 @@
|
||||||
use askama::Template;
|
|
||||||
use async_session::{MemoryStore, Session, SessionStore as _};
|
|
||||||
use axum::{
|
|
||||||
async_trait,
|
|
||||||
extract::{
|
|
||||||
rejection::TypedHeaderRejectionReason, Extension, FromRequest, RequestParts,
|
|
||||||
TypedHeader,
|
|
||||||
},
|
|
||||||
headers::Cookie,
|
|
||||||
http::{
|
|
||||||
self,
|
|
||||||
header::{HeaderValue},
|
|
||||||
StatusCode
|
|
||||||
},
|
|
||||||
response::{Html, IntoResponse, Redirect, Response},
|
|
||||||
routing::{get, get_service},
|
|
||||||
Router,
|
|
||||||
};
|
|
||||||
use http::{header};
|
|
||||||
use oauth2::{
|
|
||||||
basic::BasicClient,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::{net::SocketAddr, collections::HashMap};
|
|
||||||
use tower_http::services::ServeDir;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use sqlx::{PgPool};
|
|
||||||
use anyhow::*;
|
|
||||||
use sqlx::postgres::PgPoolOptions;
|
|
||||||
|
|
||||||
mod db;
|
|
||||||
|
|
||||||
use db::*;
|
|
||||||
|
|
||||||
mod google_oauth;
|
|
||||||
mod facebook_oauth;
|
|
||||||
mod discord_oauth;
|
|
||||||
|
|
||||||
use google_oauth::*;
|
|
||||||
use facebook_oauth::*;
|
|
||||||
use discord_oauth::*;
|
|
||||||
|
|
||||||
const COOKIE_NAME: &str = "SESSION";
|
|
||||||
|
|
||||||
// The user data we'll get back from Discord.
|
|
||||||
// https://discord.com/developers/docs/resources/user#user-object-user-structure
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct User {
|
|
||||||
id: String,
|
|
||||||
avatar: Option<String>,
|
|
||||||
username: String,
|
|
||||||
discriminator: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() {
|
|
||||||
// Set the default environment variables
|
|
||||||
if std::env::var_os("RUST_LOG").is_none() {
|
|
||||||
std::env::set_var("RUST_LOG", "example_sessions=debug")
|
|
||||||
}
|
|
||||||
if std::env::var_os("GOOGLE_CLIENT_ID").is_none() {
|
|
||||||
std::env::set_var("GOOGLE_CLIENT_ID", "735264084619-clsmvgdqdmum4rvrcj0kuk28k9agir1c.apps.googleusercontent.com")
|
|
||||||
}
|
|
||||||
if std::env::var_os("GOOGLE_CLIENT_SECRET").is_none() {
|
|
||||||
std::env::set_var("GOOGLE_CLIENT_SECRET", "L6uI7FQGoMJd-ay1HO_iGJ6M")
|
|
||||||
}
|
|
||||||
if std::env::var_os("DISCORD_CLIENT_ID").is_none() {
|
|
||||||
std::env::set_var("DISCORD_CLIENT_ID", "956189108559036427")
|
|
||||||
}
|
|
||||||
if std::env::var_os("DISCORD_CLIENT_SECRET").is_none() {
|
|
||||||
std::env::set_var("DISCORD_CLIENT_SECRET", "dx2DZxjDhVMCCnGX4xpz5MxSTgZ4lHBI")
|
|
||||||
}
|
|
||||||
if std::env::var_os("FACEBOOK_CLIENT_ID").is_none() {
|
|
||||||
std::env::set_var("FACEBOOK_CLIENT_ID", "1529124327484248")
|
|
||||||
}
|
|
||||||
if std::env::var_os("FACEBOOK_CLIENT_SECRET").is_none() {
|
|
||||||
std::env::set_var("FACEBOOK_CLIENT_SECRET", "189509b5eb907b3ce34b7e8459030f21")
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize tracing
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
// Initialize database
|
|
||||||
let db = DBApplication::new("postgres://postgres:postgres@localhost/sqlx-demo".into()).await?;
|
|
||||||
println!("Connection acquired!");
|
|
||||||
|
|
||||||
|
|
||||||
// `MemoryStore` just used as an example. Don't use this in production.
|
|
||||||
let store = MemoryStore::new();
|
|
||||||
|
|
||||||
// Create HashMap to store oauth configurations
|
|
||||||
let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
|
||||||
|
|
||||||
// Get the client structures
|
|
||||||
let facebook_oauth_client = facebook_oauth_client();
|
|
||||||
let discord_oauth_client = discord_oauth_client();
|
|
||||||
let google_oauth_client = google_oauth_client();
|
|
||||||
|
|
||||||
// Get oauth clients for the hashmap
|
|
||||||
oauth_clients.insert("Facebook", facebook_oauth_client);
|
|
||||||
oauth_clients.insert("Discord", discord_oauth_client);
|
|
||||||
oauth_clients.insert("Google", google_oauth_client);
|
|
||||||
|
|
||||||
// build our application with a route
|
|
||||||
let app = Router::new()
|
|
||||||
// `GET /` goes to `root`
|
|
||||||
.nest(
|
|
||||||
"/assets",
|
|
||||||
get_service(ServeDir::new("templates/assets")).handle_error(
|
|
||||||
|error: std::io::Error| async move {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route("/", get(index))
|
|
||||||
.route("/login", get(login))
|
|
||||||
.route("/logout", get(logout))
|
|
||||||
.route("/dashboard", get(dashboard))
|
|
||||||
.route("/google_auth", get(google_auth))
|
|
||||||
.route("/auth/google", get(google_authorized))
|
|
||||||
.route("/facebook_auth", get(facebook_auth))
|
|
||||||
.route("/auth/facebook", get(facebook_authorized))
|
|
||||||
.route("/discord_auth", get(discord_auth))
|
|
||||||
.route("/auth/discord", get(discord_authorized))
|
|
||||||
.layer(Extension(store))
|
|
||||||
.layer(Extension(oauth_clients));
|
|
||||||
|
|
||||||
// run our app with hyper
|
|
||||||
// `axum::Server` is a re-export of `hyper::Server`
|
|
||||||
let addr = SocketAddr::from(([0, 0, 0, 0], 40192));
|
|
||||||
tracing::debug!("listening on {}", addr);
|
|
||||||
axum::Server::bind(&addr)
|
|
||||||
.serve(app.into_make_service())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session is optional
|
|
||||||
async fn index(user: Option<User>) -> impl IntoResponse {
|
|
||||||
let (userid, name) = match user {
|
|
||||||
Some(u) => (true, u.username),
|
|
||||||
None => (false, "You're not logged in.".to_string()),
|
|
||||||
};
|
|
||||||
let template = IndexTemplate { userid, name };
|
|
||||||
HtmlTemplate(template)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn login() -> impl IntoResponse {
|
|
||||||
let name = "".to_string();
|
|
||||||
let userid = false;
|
|
||||||
let template = LoginTemplate { userid, name };
|
|
||||||
HtmlTemplate(template)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid user session required. If there is none, redirect to the auth page
|
|
||||||
async fn dashboard(user: User) -> impl IntoResponse {
|
|
||||||
let name = user.username;
|
|
||||||
let userid = true;
|
|
||||||
let template = DashboardTemplate { userid, name };
|
|
||||||
HtmlTemplate(template)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Template)]
|
|
||||||
#[template(path = "login.html")]
|
|
||||||
struct LoginTemplate {
|
|
||||||
userid: bool,
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Template)]
|
|
||||||
#[template(path = "dashboard.html")]
|
|
||||||
struct DashboardTemplate {
|
|
||||||
userid: bool,
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FreshUserId {
|
|
||||||
pub user_id: UserId,
|
|
||||||
pub cookie: HeaderValue,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum UserIdFromSession {
|
|
||||||
FoundUserId(UserId),
|
|
||||||
CreatedFreshUserId(FreshUserId),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<B> FromRequest<B> for UserIdFromSession
|
|
||||||
where
|
|
||||||
B: Send,
|
|
||||||
{
|
|
||||||
type Rejection = (StatusCode, &'static str);
|
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
|
||||||
let Extension(store) = Extension::<MemoryStore>::from_request(req)
|
|
||||||
.await
|
|
||||||
.expect("`MemoryStore` extension missing");
|
|
||||||
|
|
||||||
let cookie = Option::<TypedHeader<Cookie>>::from_request(req)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let session_cookie = cookie.as_ref().and_then(|cookie| cookie.get(COOKIE_NAME));
|
|
||||||
|
|
||||||
// return the new created session cookie for client
|
|
||||||
if session_cookie.is_none() {
|
|
||||||
let user_id = UserId::new();
|
|
||||||
let mut session = Session::new();
|
|
||||||
session.insert("user_id", user_id).unwrap();
|
|
||||||
let cookie = store.store_session(session).await.unwrap().unwrap();
|
|
||||||
return Ok(Self::CreatedFreshUserId(FreshUserId {
|
|
||||||
user_id,
|
|
||||||
cookie: HeaderValue::from_str(format!("{}={}", COOKIE_NAME, cookie).as_str())
|
|
||||||
.unwrap(),
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::debug!(
|
|
||||||
"UserIdFromSession: got session cookie from user agent, {}={}",
|
|
||||||
COOKIE_NAME,
|
|
||||||
session_cookie.unwrap()
|
|
||||||
);
|
|
||||||
// continue to decode the session cookie
|
|
||||||
let user_id = if let Some(session) = store
|
|
||||||
.load_session(session_cookie.unwrap().to_owned())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
{
|
|
||||||
if let Some(user_id) = session.get::<UserId>("user_id") {
|
|
||||||
tracing::debug!(
|
|
||||||
"UserIdFromSession: session decoded success, user_id={:?}",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
user_id
|
|
||||||
} else {
|
|
||||||
return Err((
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"No `user_id` found in session",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::debug!(
|
|
||||||
"UserIdFromSession: err session not exists in store, {}={}",
|
|
||||||
COOKIE_NAME,
|
|
||||||
session_cookie.unwrap()
|
|
||||||
);
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "No session found for cookie"));
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self::FoundUserId(user_id))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
|
|
||||||
struct UserId(Uuid);
|
|
||||||
|
|
||||||
impl UserId {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self(Uuid::new_v4())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Template)]
|
|
||||||
#[template(path = "index.html")]
|
|
||||||
struct IndexTemplate {
|
|
||||||
userid: bool,
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HtmlTemplate<T>(T);
|
|
||||||
|
|
||||||
impl<T> IntoResponse for HtmlTemplate<T>
|
|
||||||
where
|
|
||||||
T: Template,
|
|
||||||
{
|
|
||||||
fn into_response(self) -> Response {
|
|
||||||
match self.0.render() {
|
|
||||||
Ok(html) => Html(html).into_response(),
|
|
||||||
Err(err) => (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to render template. Error: {}", err),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn logout(
|
|
||||||
Extension(store): Extension<MemoryStore>,
|
|
||||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let cookie = cookies.get(COOKIE_NAME).unwrap();
|
|
||||||
let session = match store.load_session(cookie.to_string()).await.unwrap() {
|
|
||||||
Some(s) => s,
|
|
||||||
// No session active, just redirect
|
|
||||||
None => return Redirect::to(&"/"),
|
|
||||||
};
|
|
||||||
|
|
||||||
store.destroy_session(session).await.unwrap();
|
|
||||||
|
|
||||||
Redirect::to(&"/")
|
|
||||||
}
|
|
||||||
|
|
||||||
struct AuthRedirect;
|
|
||||||
|
|
||||||
impl IntoResponse for AuthRedirect {
|
|
||||||
fn into_response(self) -> Response {
|
|
||||||
Redirect::temporary(&"/login").into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<B> FromRequest<B> for User
|
|
||||||
where
|
|
||||||
B: Send,
|
|
||||||
{
|
|
||||||
// If anything goes wrong or no session is found, redirect to the auth page
|
|
||||||
type Rejection = AuthRedirect;
|
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
|
||||||
let Extension(store) = Extension::<MemoryStore>::from_request(req)
|
|
||||||
.await
|
|
||||||
.expect("`MemoryStore` extension is missing");
|
|
||||||
|
|
||||||
let cookies = TypedHeader::<headers::Cookie>::from_request(req)
|
|
||||||
.await
|
|
||||||
.map_err(|e| match *e.name() {
|
|
||||||
header::COOKIE => match e.reason() {
|
|
||||||
TypedHeaderRejectionReason::Missing => AuthRedirect,
|
|
||||||
_ => panic!("unexpected error getting Cookie header(s): {}", e),
|
|
||||||
},
|
|
||||||
_ => panic!("unexpected error getting cookies: {}", e),
|
|
||||||
})?;
|
|
||||||
let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;
|
|
||||||
|
|
||||||
let session = store
|
|
||||||
.load_session(session_cookie.to_string())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.ok_or(AuthRedirect)?;
|
|
||||||
|
|
||||||
let user = session.get::<User>("user").ok_or(AuthRedirect)?;
|
|
||||||
|
|
||||||
Ok(user)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -5,12 +5,21 @@ use axum::{
|
||||||
Router,
|
Router,
|
||||||
routing::{get, get_service}, response::{Html, IntoResponse, Response},
|
routing::{get, get_service}, response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
use axum_extra::extract::cookie::PrivateCookieJar;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
|
use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
|
||||||
use tower_http::services::ServeDir;
|
use tower_http::services::ServeDir;
|
||||||
|
|
||||||
|
mod error_handling;
|
||||||
mod google_oauth;
|
mod google_oauth;
|
||||||
|
mod middlewares;
|
||||||
|
mod oauth;
|
||||||
|
|
||||||
|
use error_handling::AppError;
|
||||||
use google_oauth::*;
|
use google_oauth::*;
|
||||||
|
use middlewares::{check_auth, inject_user_data};
|
||||||
|
use oauth::{login, logout, oauth_return};
|
||||||
|
|
||||||
struct HtmlTemplate<T>(T);
|
struct HtmlTemplate<T>(T);
|
||||||
|
|
||||||
|
|
@ -51,6 +60,18 @@ struct DashboardTemplate {
|
||||||
name: String,
|
name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub db_pool: SqlitePool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct UserData {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub user_id: i64,
|
||||||
|
pub user_email: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct User {
|
struct User {
|
||||||
id: String,
|
id: String,
|
||||||
|
|
@ -67,7 +88,17 @@ async fn main() {
|
||||||
let session_config = SessionConfig::default()
|
let session_config = SessionConfig::default()
|
||||||
.with_table_name("sessions");
|
.with_table_name("sessions");
|
||||||
|
|
||||||
let session_store = SessionStore::<SessionAnyPool>::new(None, session_config).await.unwrap();
|
//let session_store = SessionStore::<SessionAnyPool>::new(None, session_config).await.unwrap();
|
||||||
|
|
||||||
|
let db_pool = SqlitePoolOptions::new()
|
||||||
|
.max_connections(5)
|
||||||
|
.connect("sqlite://db/db.sqlite3")
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("DB connection failed: {}", e))?;
|
||||||
|
|
||||||
|
let app_state = AppState {
|
||||||
|
db_pool: db_pool
|
||||||
|
};
|
||||||
|
|
||||||
// Create HashMap to store oauth configurations
|
// Create HashMap to store oauth configurations
|
||||||
// let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
// let mut oauth_clients = HashMap::<&str, BasicClient>::new();
|
||||||
|
|
@ -91,7 +122,9 @@ async fn main() {
|
||||||
.route("/login", get(login))
|
.route("/login", get(login))
|
||||||
.route("/google_auth", get(google_auth))
|
.route("/google_auth", get(google_auth))
|
||||||
.route("/auth/google", get(google_authorized))
|
.route("/auth/google", get(google_authorized))
|
||||||
.layer(SessionLayer::new(session_store))
|
.route("/logout", get(logout))
|
||||||
|
//.layer(SessionLayer::new(session_store))
|
||||||
|
.with_state(app_state)
|
||||||
// .layer(Extension(oauth_clients));
|
// .layer(Extension(oauth_clients));
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|
@ -112,7 +145,7 @@ async fn index(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
||||||
HtmlTemplate(template)
|
HtmlTemplate(template)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn login(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
/* async fn login(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
||||||
let logged_in = session.get("logged_in").unwrap_or(false);
|
let logged_in = session.get("logged_in").unwrap_or(false);
|
||||||
let name = session.get("name").unwrap_or("".to_string());
|
let name = session.get("name").unwrap_or("".to_string());
|
||||||
|
|
||||||
|
|
@ -123,7 +156,7 @@ async fn login(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
||||||
|
|
||||||
let template = LoginTemplate { logged_in, name };
|
let template = LoginTemplate { logged_in, name };
|
||||||
HtmlTemplate(template)
|
HtmlTemplate(template)
|
||||||
}
|
} */
|
||||||
|
|
||||||
async fn dashboard(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
async fn dashboard(session: Session<SessionAnyPool>) -> impl IntoResponse {
|
||||||
let logged_in = session.get("logged_in").unwrap_or(false);
|
let logged_in = session.get("logged_in").unwrap_or(false);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
use super::{AppError, UserData};
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::State,
|
||||||
|
http::Request,
|
||||||
|
middleware::Next,
|
||||||
|
response::{IntoResponse, Redirect},
|
||||||
|
};
|
||||||
|
use axum_extra::TypedHeader;
|
||||||
|
use chrono::Utc;
|
||||||
|
use headers::Cookie;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
|
||||||
|
pub async fn inject_user_data(
|
||||||
|
State(db_pool): State<SqlitePool>,
|
||||||
|
cookie: Option<TypedHeader<Cookie>>,
|
||||||
|
mut request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<impl IntoResponse, AppError> {
|
||||||
|
if let Some(cookie) = cookie {
|
||||||
|
if let Some(session_token) = cookie.get("session_token") {
|
||||||
|
let session_token: Vec<&str> = session_token.split('_').collect();
|
||||||
|
let query: Result<(i64, i64, String), _> = sqlx::query_as(
|
||||||
|
r#"SELECT user_id,expires_at,session_token_p2 FROM user_sessions WHERE session_token_p1=?"#,
|
||||||
|
)
|
||||||
|
.bind(session_token[0])
|
||||||
|
.fetch_one(&db_pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Ok(query) = query {
|
||||||
|
if let Ok(session_token_p2_db) = query.2.as_bytes().try_into() {
|
||||||
|
if let Ok(session_token_p2_cookie) = session_token
|
||||||
|
.get(1)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_bytes()
|
||||||
|
.try_into()
|
||||||
|
{
|
||||||
|
if constant_time_eq::constant_time_eq_n::<36>(
|
||||||
|
session_token_p2_cookie,
|
||||||
|
session_token_p2_db,
|
||||||
|
) {
|
||||||
|
let user_id = query.0;
|
||||||
|
let expires_at = query.1;
|
||||||
|
if expires_at > Utc::now().timestamp() {
|
||||||
|
let query: Result<(String,), _> =
|
||||||
|
sqlx::query_as(r#"SELECT email FROM users WHERE id=?"#)
|
||||||
|
.bind(user_id)
|
||||||
|
.fetch_one(&db_pool)
|
||||||
|
.await;
|
||||||
|
if let Ok(query) = query {
|
||||||
|
let user_email = query.0;
|
||||||
|
request.extensions_mut().insert(Some(UserData {
|
||||||
|
user_id,
|
||||||
|
user_email,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_auth(request: Request<Body>, next: Next) -> Result<impl IntoResponse, AppError> {
|
||||||
|
if request
|
||||||
|
.extensions()
|
||||||
|
.get::<Option<UserData>>()
|
||||||
|
.ok_or("check_auth: extensions have no UserData")?
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
} else {
|
||||||
|
let login_url = "/login?return_url=".to_owned() + &*request.uri().to_string();
|
||||||
|
Ok(Redirect::to(login_url.as_str()).into_response())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,229 @@
|
||||||
pub struct OauthSession {
|
// Code adapted from https://github.com/ramosbugs/oauth2-rs/blob/main/examples/google.rs
|
||||||
client: BasicClient,
|
//
|
||||||
pkce_code_verifier: PkceCodeVerifier,
|
// Must set the enviroment variables:
|
||||||
csrf_state: CsrfToken
|
// GOOGLE_CLIENT_ID=xxx
|
||||||
|
// GOOGLE_CLIENT_SECRET=yyy
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Host, Query, State},
|
||||||
|
response::{IntoResponse, Redirect},
|
||||||
|
};
|
||||||
|
use axum_extra::TypedHeader;
|
||||||
|
use dotenvy::var;
|
||||||
|
use headers::Cookie;
|
||||||
|
use oauth2::{
|
||||||
|
basic::BasicClient, reqwest::http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret,
|
||||||
|
CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope,
|
||||||
|
TokenResponse, TokenUrl,
|
||||||
|
};
|
||||||
|
|
||||||
|
use chrono::Utc;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::{AppError, UserData};
|
||||||
|
|
||||||
|
fn get_client(hostname: String) -> Result<BasicClient, AppError> {
|
||||||
|
let google_client_id = ClientId::new(var("GOOGLE_CLIENT_ID")?);
|
||||||
|
let google_client_secret = ClientSecret::new(var("GOOGLE_CLIENT_SECRET")?);
|
||||||
|
let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
|
||||||
|
.map_err(|_| "OAuth: invalid authorization endpoint URL")?;
|
||||||
|
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
|
||||||
|
.map_err(|_| "OAuth: invalid token endpoint URL")?;
|
||||||
|
|
||||||
|
let protocol = if hostname.starts_with("localhost") || hostname.starts_with("127.0.0.1") {
|
||||||
|
"http"
|
||||||
|
} else {
|
||||||
|
"https"
|
||||||
|
};
|
||||||
|
|
||||||
|
let redirect_url = format!("{}://{}/oauth_return", protocol, hostname);
|
||||||
|
|
||||||
|
// Set up the config for the Google OAuth2 process.
|
||||||
|
let client = BasicClient::new(
|
||||||
|
google_client_id,
|
||||||
|
Some(google_client_secret),
|
||||||
|
auth_url,
|
||||||
|
Some(token_url),
|
||||||
|
)
|
||||||
|
.set_redirect_uri(RedirectUrl::new(redirect_url).map_err(|_| "OAuth: invalid redirect URL")?)
|
||||||
|
.set_revocation_uri(
|
||||||
|
RevocationUrl::new("https://oauth2.googleapis.com/revoke".to_string())
|
||||||
|
.map_err(|_| "OAuth: invalid revocation endpoint URL")?,
|
||||||
|
);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn login(
|
||||||
|
Extension(user_data): Extension<Option<UserData>>,
|
||||||
|
Query(mut params): Query<HashMap<String, String>>,
|
||||||
|
State(db_pool): State<SqlitePool>,
|
||||||
|
Host(hostname): Host,
|
||||||
|
) -> Result<Redirect, AppError> {
|
||||||
|
if user_data.is_some() {
|
||||||
|
// check if already authenticated
|
||||||
|
return Ok(Redirect::to("/"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let return_url = params
|
||||||
|
.remove("return_url")
|
||||||
|
.unwrap_or_else(|| "/".to_string());
|
||||||
|
// TODO: check if return_url is valid
|
||||||
|
|
||||||
|
let client = get_client(hostname)?;
|
||||||
|
|
||||||
|
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
|
||||||
|
let (authorize_url, csrf_state) = client
|
||||||
|
.authorize_url(CsrfToken::new_random)
|
||||||
|
.add_scope(Scope::new(
|
||||||
|
"https://www.googleapis.com/auth/userinfo.email".to_string(),
|
||||||
|
))
|
||||||
|
.set_pkce_challenge(pkce_code_challenge)
|
||||||
|
.url();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO oauth2_state_storage (csrf_state, pkce_code_verifier, return_url) VALUES (?, ?, ?);",
|
||||||
|
)
|
||||||
|
.bind(csrf_state.secret())
|
||||||
|
.bind(pkce_code_verifier.secret())
|
||||||
|
.bind(return_url)
|
||||||
|
.execute(&db_pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Redirect::to(authorize_url.as_str()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn oauth_return(
|
||||||
|
Query(mut params): Query<HashMap<String, String>>,
|
||||||
|
State(db_pool): State<SqlitePool>,
|
||||||
|
Host(hostname): Host,
|
||||||
|
) -> Result<impl IntoResponse, AppError> {
|
||||||
|
let state = CsrfToken::new(params.remove("state").ok_or("OAuth: without state")?);
|
||||||
|
let code = AuthorizationCode::new(params.remove("code").ok_or("OAuth: without code")?);
|
||||||
|
|
||||||
|
let query: (String, String) = sqlx::query_as(
|
||||||
|
r#"DELETE FROM oauth2_state_storage WHERE csrf_state = ? RETURNING pkce_code_verifier,return_url"#,
|
||||||
|
)
|
||||||
|
.bind(state.secret())
|
||||||
|
.fetch_one(&db_pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Alternative:
|
||||||
|
// let query: (String, String) = sqlx::query_as(
|
||||||
|
// r#"SELECT pkce_code_verifier,return_url FROM oauth2_state_storage WHERE csrf_state = ?"#,
|
||||||
|
// )
|
||||||
|
// .bind(state.secret())
|
||||||
|
// .fetch_one(&db_pool)
|
||||||
|
// .await?;
|
||||||
|
// let _ = sqlx::query("DELETE FROM oauth2_state_storage WHERE csrf_state = ?")
|
||||||
|
// .bind(state.secret())
|
||||||
|
// .execute(&db_pool)
|
||||||
|
// .await;
|
||||||
|
|
||||||
|
let pkce_code_verifier = query.0;
|
||||||
|
let return_url = query.1;
|
||||||
|
let pkce_code_verifier = PkceCodeVerifier::new(pkce_code_verifier);
|
||||||
|
|
||||||
|
// Exchange the code with a token.
|
||||||
|
let client = get_client(hostname)?;
|
||||||
|
let token_response = tokio::task::spawn_blocking(move || {
|
||||||
|
client
|
||||||
|
.exchange_code(code)
|
||||||
|
.set_pkce_verifier(pkce_code_verifier)
|
||||||
|
.request(http_client)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|_| "OAuth: exchange_code failure")?
|
||||||
|
.map_err(|_| "OAuth: tokio spawn blocking failure")?;
|
||||||
|
let access_token = token_response.access_token().secret();
|
||||||
|
|
||||||
|
// Get user info from Google
|
||||||
|
let url =
|
||||||
|
"https://www.googleapis.com/oauth2/v2/userinfo?oauth_token=".to_owned() + access_token;
|
||||||
|
let body = reqwest::get(url)
|
||||||
|
.await
|
||||||
|
.map_err(|_| "OAuth: reqwest failed to query userinfo")?
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|_| "OAuth: reqwest received invalid userinfo")?;
|
||||||
|
let mut body: serde_json::Value =
|
||||||
|
serde_json::from_str(body.as_str()).map_err(|_| "OAuth: Serde failed to parse userinfo")?;
|
||||||
|
let email = body["email"]
|
||||||
|
.take()
|
||||||
|
.as_str()
|
||||||
|
.ok_or("OAuth: Serde failed to parse email address")?
|
||||||
|
.to_owned();
|
||||||
|
let verified_email = body["verified_email"]
|
||||||
|
.take()
|
||||||
|
.as_bool()
|
||||||
|
.ok_or("OAuth: Serde failed to parse verified_email")?;
|
||||||
|
if !verified_email {
|
||||||
|
return Err(AppError::new("OAuth: email address is not verified".to_owned())
|
||||||
|
.with_user_message("Your email address is not verified. Please verify your email address with Google and try again.".to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user exists in database
|
||||||
|
// If not, create a new user
|
||||||
|
let query: Result<(i64,), _> = sqlx::query_as(r#"SELECT id FROM users WHERE email=?"#)
|
||||||
|
.bind(email.as_str())
|
||||||
|
.fetch_one(&db_pool)
|
||||||
|
.await;
|
||||||
|
let user_id = if let Ok(query) = query {
|
||||||
|
query.0
|
||||||
|
} else {
|
||||||
|
let query: (i64,) = sqlx::query_as("INSERT INTO users (email) VALUES (?) RETURNING id")
|
||||||
|
.bind(email)
|
||||||
|
.fetch_one(&db_pool)
|
||||||
|
.await?;
|
||||||
|
query.0
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a session for the user
|
||||||
|
let session_token_p1 = Uuid::new_v4().to_string();
|
||||||
|
let session_token_p2 = Uuid::new_v4().to_string();
|
||||||
|
let session_token = [session_token_p1.as_str(), "_", session_token_p2.as_str()].concat();
|
||||||
|
let headers = axum::response::AppendHeaders([(
|
||||||
|
axum::http::header::SET_COOKIE,
|
||||||
|
"session_token=".to_owned()
|
||||||
|
+ &*session_token
|
||||||
|
+ "; path=/; httponly; secure; samesite=strict",
|
||||||
|
)]);
|
||||||
|
let now = Utc::now().timestamp();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO user_sessions
|
||||||
|
(session_token_p1, session_token_p2, user_id, created_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?);",
|
||||||
|
)
|
||||||
|
.bind(session_token_p1)
|
||||||
|
.bind(session_token_p2)
|
||||||
|
.bind(user_id)
|
||||||
|
.bind(now)
|
||||||
|
.bind(now + 60 * 60 * 24)
|
||||||
|
.execute(&db_pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok((headers, Redirect::to(return_url.as_str())))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn logout(
|
||||||
|
cookie: Option<TypedHeader<Cookie>>,
|
||||||
|
State(db_pool): State<SqlitePool>,
|
||||||
|
) -> Result<impl IntoResponse, AppError> {
|
||||||
|
if let Some(cookie) = cookie {
|
||||||
|
if let Some(session_token) = cookie.get("session_token") {
|
||||||
|
let session_token: Vec<&str> = session_token.split('_').collect();
|
||||||
|
let _ = sqlx::query("DELETE FROM user_sessions WHERE session_token_1 = ?")
|
||||||
|
.bind(session_token[0])
|
||||||
|
.execute(&db_pool)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let headers = axum::response::AppendHeaders([(
|
||||||
|
axum::http::header::SET_COOKIE,
|
||||||
|
"session_token=deleted; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT",
|
||||||
|
)]);
|
||||||
|
Ok((headers, Redirect::to("/")))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}About{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
This is a demo OAuth website.
|
||||||
|
{% endblock %}
|
||||||
|
|
@ -38,14 +38,14 @@
|
||||||
src="/assets/icons/numix-circle/web-google.svg"
|
src="/assets/icons/numix-circle/web-google.svg"
|
||||||
width="100" alt="Login with Google">
|
width="100" alt="Login with Google">
|
||||||
</a>
|
</a>
|
||||||
<a href="/facebook_auth" class="px-2" title="Login with Facebook"> <img
|
<!-- <a href="/facebook_auth" class="px-2" title="Login with Facebook"> <img
|
||||||
src="/assets/icons/numix-circle/web-facebook.svg"
|
src="/assets/icons/numix-circle/web-facebook.svg"
|
||||||
width="100" alt="Login with Facebook">
|
width="100" alt="Login with Facebook">
|
||||||
</a>
|
</a>
|
||||||
<a href="/discord_auth" class="px-2" title="Login with Discord">
|
<a href="/discord_auth" class="px-2" title="Login with Discord">
|
||||||
<img src="/assets/icons/numix-circle/discord.svg" width="100" alt="Login with Discord" />
|
<img src="/assets/icons/numix-circle/discord.svg" width="100" alt="Login with Discord" />
|
||||||
</a>
|
</a>
|
||||||
|
-->
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}User Profile{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
This is your user profile page. <br />
|
||||||
|
Your email address: {{ user_email }}.
|
||||||
|
{% endblock %}
|
||||||
Loading…
Reference in New Issue