Refactor - place all non auth routes in routes.rs
This commit is contained in:
parent
464a05638b
commit
06a6811972
|
|
@ -1,63 +1,253 @@
|
|||
use axum::
|
||||
response::{IntoResponse, Redirect}
|
||||
;
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenUrl
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
// Code adapted from https://github.com/ramosbugs/oauth2-rs/blob/main/examples/google.rs
|
||||
//
|
||||
// Must set the enviroment variables:
|
||||
// GOOGLE_CLIENT_ID=xxx
|
||||
// GOOGLE_CLIENT_SECRET=yyy
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AuthRequest {
|
||||
code: String,
|
||||
state: String,
|
||||
use axum::{
|
||||
extract::{Extension, Host, Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use dotenvy::var;
|
||||
use headers::Cookie;
|
||||
use oauth2::{
|
||||
basic::BasicClient, reqwest::http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret,
|
||||
CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope,
|
||||
TokenResponse, TokenUrl,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use sqlx::SqlitePool;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{AppError, UserData};
|
||||
|
||||
fn get_client(hostname: String) -> Result<BasicClient, AppError> {
|
||||
let google_client_id = ClientId::new(var("GOOGLE_CLIENT_ID")?);
|
||||
let google_client_secret = ClientSecret::new(var("GOOGLE_CLIENT_SECRET")?);
|
||||
let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
|
||||
.map_err(|_| "OAuth: invalid authorization endpoint URL")?;
|
||||
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
|
||||
.map_err(|_| "OAuth: invalid token endpoint URL")?;
|
||||
|
||||
let protocol = if hostname.starts_with("localhost") || hostname.starts_with("127.0.0.1") {
|
||||
"http"
|
||||
} else {
|
||||
"https"
|
||||
};
|
||||
|
||||
let redirect_url = format!("{}://{}/google_auth_return", protocol, hostname);
|
||||
|
||||
// Set up the config for the Google OAuth2 process.
|
||||
let client = BasicClient::new(
|
||||
google_client_id,
|
||||
Some(google_client_secret),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).map_err(|_| "OAuth: invalid redirect URL")?)
|
||||
.set_revocation_uri(
|
||||
RevocationUrl::new("https://oauth2.googleapis.com/revoke".to_string())
|
||||
.map_err(|_| "OAuth: invalid revocation endpoint URL")?,
|
||||
);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn google_auth() -> impl IntoResponse {
|
||||
let (pkce_code_challenge, _pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
pub async fn login(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
Host(hostname): Host,
|
||||
) -> Result<Redirect, AppError> {
|
||||
|
||||
// Generate the authorization URL to which we'll redirect the user.
|
||||
let (auth_url, _csrf_state) = google_oauth_client()
|
||||
if user_data.is_some() {
|
||||
// check if already authenticated
|
||||
return Ok(Redirect::to("/"));
|
||||
}
|
||||
|
||||
let return_url = params
|
||||
.remove("return_url")
|
||||
.unwrap_or_else(|| "/dashboard".to_string());
|
||||
// TODO: check if return_url is valid
|
||||
println!("Return URL: {}", return_url);
|
||||
|
||||
let client = get_client(hostname)?;
|
||||
|
||||
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let (authorize_url, csrf_state) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new(
|
||||
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
||||
"https://www.googleapis.com/auth/userinfo.email".to_string(),
|
||||
))
|
||||
.add_scope(Scope::new(
|
||||
"https://www.googleapis.com/auth/userinfo.email".to_string(),
|
||||
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
||||
))
|
||||
.set_pkce_challenge(pkce_code_challenge)
|
||||
.url();
|
||||
|
||||
// Redirect to Google's oauth service
|
||||
Redirect::to(&auth_url.to_string())
|
||||
}
|
||||
|
||||
pub fn google_oauth_client() -> BasicClient {
|
||||
if std::env::var_os("GOOGLE_CLIENT_ID").is_none() {
|
||||
std::env::set_var("GOOGLE_CLIENT_ID", "735264084619-clsmvgdqdmum4rvrcj0kuk28k9agir1c.apps.googleusercontent.com")
|
||||
}
|
||||
if std::env::var_os("GOOGLE_CLIENT_SECRET").is_none() {
|
||||
std::env::set_var("GOOGLE_CLIENT_SECRET", "L6uI7FQGoMJd-ay1HO_iGJ6M")
|
||||
}
|
||||
|
||||
let redirect_url = env::var("REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:40192/google_auth_return".to_string());
|
||||
// .unwrap_or_else(|_| "https://www.jean-marie.ca/auth/google".to_string());
|
||||
|
||||
let google_client_id = env::var("GOOGLE_CLIENT_ID").expect("Missing GOOGLE_CLIENT_ID!");
|
||||
let google_client_secret =
|
||||
env::var("GOOGLE_CLIENT_SECRET").expect("Missing GOOGLE_CLIENT_SECRET!");
|
||||
let google_auth_url = env::var("GOOGLE_AUTH_URL")
|
||||
.unwrap_or_else(|_| "https://accounts.google.com/o/oauth2/v2/auth".to_string());
|
||||
let google_token_url = env::var("GOOGLE_TOKEN_URL")
|
||||
.unwrap_or_else(|_| "https://www.googleapis.com/oauth2/v3/token".to_string());
|
||||
|
||||
BasicClient::new(
|
||||
ClientId::new(google_client_id),
|
||||
Some(ClientSecret::new(google_client_secret)),
|
||||
AuthUrl::new(google_auth_url).unwrap(),
|
||||
Some(TokenUrl::new(google_token_url).unwrap()),
|
||||
sqlx::query(
|
||||
"INSERT INTO oauth2_state_storage (csrf_state, pkce_code_verifier, return_url) VALUES (?, ?, ?);",
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
|
||||
.bind(csrf_state.secret())
|
||||
.bind(pkce_code_verifier.secret())
|
||||
.bind(return_url)
|
||||
.execute(&db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(Redirect::to(authorize_url.as_str()))
|
||||
}
|
||||
|
||||
pub async fn google_auth_return(
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
Host(hostname): Host,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
let state = CsrfToken::new(params.remove("state").ok_or("OAuth: without state")?);
|
||||
let code = AuthorizationCode::new(params.remove("code").ok_or("OAuth: without code")?);
|
||||
|
||||
let query: (String, String) = sqlx::query_as(
|
||||
r#"DELETE FROM oauth2_state_storage WHERE csrf_state = ? RETURNING pkce_code_verifier,return_url"#,
|
||||
)
|
||||
.bind(state.secret())
|
||||
.fetch_one(&db_pool)
|
||||
.await?;
|
||||
|
||||
// Alternative:
|
||||
// let query: (String, String) = sqlx::query_as(
|
||||
// r#"SELECT pkce_code_verifier,return_url FROM oauth2_state_storage WHERE csrf_state = ?"#,
|
||||
// )
|
||||
// .bind(state.secret())
|
||||
// .fetch_one(&db_pool)
|
||||
// .await?;
|
||||
// let _ = sqlx::query("DELETE FROM oauth2_state_storage WHERE csrf_state = ?")
|
||||
// .bind(state.secret())
|
||||
// .execute(&db_pool)
|
||||
// .await;
|
||||
|
||||
let pkce_code_verifier = query.0;
|
||||
let return_url = query.1;
|
||||
let pkce_code_verifier = PkceCodeVerifier::new(pkce_code_verifier);
|
||||
|
||||
// Exchange the code with a token.
|
||||
let client = get_client(hostname)?;
|
||||
let token_response = tokio::task::spawn_blocking(move || {
|
||||
client
|
||||
.exchange_code(code)
|
||||
.set_pkce_verifier(pkce_code_verifier)
|
||||
.request(http_client)
|
||||
})
|
||||
.await
|
||||
.map_err(|_| "OAuth: exchange_code failure")?
|
||||
.map_err(|_| "OAuth: tokio spawn blocking failure")?;
|
||||
let access_token = token_response.access_token().secret();
|
||||
|
||||
// Get user info from Google
|
||||
let url =
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo?oauth_token=".to_owned() + access_token;
|
||||
let body = reqwest::get(url)
|
||||
.await
|
||||
.map_err(|_| "OAuth: reqwest failed to query userinfo")?
|
||||
.text()
|
||||
.await
|
||||
.map_err(|_| "OAuth: reqwest received invalid userinfo")?;
|
||||
let mut body: serde_json::Value =
|
||||
serde_json::from_str(body.as_str()).map_err(|_| "OAuth: Serde failed to parse userinfo")?;
|
||||
let email = body["email"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let name = body["name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let family_name = body["family_name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let given_name = body["given_name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let verified_email = body["verified_email"]
|
||||
.take()
|
||||
.as_bool()
|
||||
.ok_or("OAuth: Serde failed to parse verified_email")?;
|
||||
if !verified_email {
|
||||
return Err(AppError::new("OAuth: email address is not verified".to_owned())
|
||||
.with_user_message("Your email address is not verified. Please verify your email address with Google and try again.".to_owned()));
|
||||
}
|
||||
|
||||
// Check if user exists in database
|
||||
// If not, create a new user
|
||||
let query: Result<(i64,), _> = sqlx::query_as(r#"SELECT id FROM users WHERE email=?"#)
|
||||
.bind(email.as_str())
|
||||
.fetch_one(&db_pool)
|
||||
.await;
|
||||
let user_id = if let Ok(query) = query {
|
||||
query.0
|
||||
} else {
|
||||
let query: (i64,) = sqlx::query_as("INSERT INTO users (email, name, family_name, given_name) VALUES (?, ?, ?, ?) RETURNING id")
|
||||
.bind(email)
|
||||
.bind(name)
|
||||
.bind(family_name)
|
||||
.bind(given_name)
|
||||
.fetch_one(&db_pool)
|
||||
.await?;
|
||||
query.0
|
||||
};
|
||||
|
||||
// Create a session for the user
|
||||
let session_token_p1 = Uuid::new_v4().to_string();
|
||||
let session_token_p2 = Uuid::new_v4().to_string();
|
||||
let session_token = [session_token_p1.as_str(), "_", session_token_p2.as_str()].concat();
|
||||
let headers = axum::response::AppendHeaders([(
|
||||
axum::http::header::SET_COOKIE,
|
||||
"session_token=".to_owned()
|
||||
+ &*session_token
|
||||
+ "; path=/; httponly; secure; samesite=strict",
|
||||
)]);
|
||||
let now = Utc::now().timestamp();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO user_sessions
|
||||
(session_token_p1, session_token_p2, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?);",
|
||||
)
|
||||
.bind(session_token_p1)
|
||||
.bind(session_token_p2)
|
||||
.bind(user_id)
|
||||
.bind(now)
|
||||
.bind(now + 60 * 60 * 24)
|
||||
.execute(&db_pool)
|
||||
.await?;
|
||||
|
||||
println!("Returning to: {}", return_url);
|
||||
Ok((headers, Redirect::to(return_url.as_str())))
|
||||
}
|
||||
|
||||
pub async fn logout(
|
||||
cookie: Option<TypedHeader<Cookie>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
if let Some(cookie) = cookie {
|
||||
if let Some(session_token) = cookie.get("session_token") {
|
||||
let session_token: Vec<&str> = session_token.split('_').collect();
|
||||
let _ = sqlx::query("DELETE FROM user_sessions WHERE session_token_1 = ?")
|
||||
.bind(session_token[0])
|
||||
.execute(&db_pool)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
let headers = axum::response::AppendHeaders([(
|
||||
axum::http::header::SET_COOKIE,
|
||||
"session_token=deleted; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT",
|
||||
)]);
|
||||
Ok((headers, Redirect::to("/")))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
use std::net::SocketAddr;
|
||||
use askama_axum::Template;
|
||||
use axum::{
|
||||
middleware, response::{Html, IntoResponse, Response}, routing::{get, get_service}, Extension, Router
|
||||
middleware, routing::{get, get_service}, Extension, Router
|
||||
};
|
||||
use http::{Request, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{prelude::FromRow, sqlite::SqlitePoolOptions, SqlitePool};
|
||||
use sqlx::migrate::Migrator;
|
||||
|
|
@ -12,39 +10,12 @@ use tower_http::services::ServeDir;
|
|||
mod error_handling;
|
||||
mod google_oauth;
|
||||
mod middlewares;
|
||||
mod oauth;
|
||||
mod routes;
|
||||
|
||||
use error_handling::AppError;
|
||||
use google_oauth::*;
|
||||
use middlewares::{check_auth, inject_user_data};
|
||||
use oauth::{login, logout, google_auth_return};
|
||||
use routes::*;
|
||||
|
||||
struct HtmlTemplate<T>(T);
|
||||
|
||||
impl<T> IntoResponse for HtmlTemplate<T>
|
||||
where
|
||||
T: Template,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
match self.0.render() {
|
||||
Ok(html) => Html(html).into_response(),
|
||||
Err(err) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to render template. Error: {}", err),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "index.html")]
|
||||
struct IndexTemplate {
|
||||
logged_in: bool,
|
||||
name: String,
|
||||
}
|
||||
use google_oauth::{login, logout, google_auth_return};
|
||||
use routes::{dashboard, index, profile, user_profile, useradmin};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
|
|
@ -85,6 +56,7 @@ async fn main() {
|
|||
// build our application with some routes
|
||||
let app = Router::new()
|
||||
//Routes that require authentication
|
||||
.route("/dashboard", get(dashboard))
|
||||
.route("/profile", get(profile))
|
||||
.route("/useradmin", get(useradmin))
|
||||
.route("/users/:user_id", get(user_profile))
|
||||
|
|
@ -96,7 +68,7 @@ async fn main() {
|
|||
.route("/", get(index))
|
||||
.route("/login", get(login))
|
||||
.route("/logout", get(logout))
|
||||
.route("/google_auth", get(google_auth))
|
||||
//.route("/google_auth", get(google_auth))
|
||||
.route("/google_auth_return", get(google_auth_return))
|
||||
.route_layer(middleware::from_fn_with_state(app_state.db_pool.clone(), inject_user_data))
|
||||
.with_state(app_state.db_pool)
|
||||
|
|
@ -112,15 +84,3 @@ async fn main() {
|
|||
.unwrap();
|
||||
|
||||
}
|
||||
|
||||
async fn index<T>(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
_request: Request<T>,
|
||||
) -> impl IntoResponse {
|
||||
let user_email = user_data.map(|s| s.name);
|
||||
let logged_in = user_email.is_some();
|
||||
let name = user_email.unwrap_or_default();
|
||||
|
||||
let template = IndexTemplate { logged_in, name};
|
||||
HtmlTemplate(template)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,11 @@ pub async fn inject_user_data(
|
|||
mut request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
println!("Injecting user data");
|
||||
if let Some(cookie) = cookie {
|
||||
println!("{:#?}", cookie.get("session_token"));
|
||||
if let Some(session_token) = cookie.get("session_token") {
|
||||
println!("Found session token: {}", session_token);
|
||||
let session_token: Vec<&str> = session_token.split('_').collect();
|
||||
let query: Result<(i64, i64, String), _> = sqlx::query_as(
|
||||
r#"SELECT user_id,expires_at,session_token_p2 FROM user_sessions WHERE session_token_p1=?"#,
|
||||
|
|
@ -43,6 +46,7 @@ pub async fn inject_user_data(
|
|||
session_token_p2_db,
|
||||
) {
|
||||
let id = query.0;
|
||||
println!("Found user: {}", id);
|
||||
let expires_at = query.1;
|
||||
if expires_at > Utc::now().timestamp() {
|
||||
let row = sqlx::query_as!(
|
||||
|
|
@ -60,6 +64,8 @@ pub async fn inject_user_data(
|
|||
family_name: row.family_name,
|
||||
given_name: row.given_name,
|
||||
}));
|
||||
} else {
|
||||
println!("Session expired");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -76,6 +82,8 @@ pub async fn check_auth(
|
|||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
println!("check_auth");
|
||||
println!("{:#?}", request);
|
||||
if request
|
||||
.extensions()
|
||||
.get::<Option<UserData>>()
|
||||
|
|
@ -165,3 +173,13 @@ pub async fn check_auth(
|
|||
Ok(Redirect::to(login_url.as_str()).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
fn access_auth (_uri: &str, _userid: i64) {
|
||||
// Check uri and userid for access and build the menu for links to the allowed resources
|
||||
|
||||
let _user_auth = true;
|
||||
|
||||
let _menu: Vec<String> = vec![];
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -1,251 +0,0 @@
|
|||
// Code adapted from https://github.com/ramosbugs/oauth2-rs/blob/main/examples/google.rs
|
||||
//
|
||||
// Must set the enviroment variables:
|
||||
// GOOGLE_CLIENT_ID=xxx
|
||||
// GOOGLE_CLIENT_SECRET=yyy
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Host, Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use dotenvy::var;
|
||||
use headers::Cookie;
|
||||
use oauth2::{
|
||||
basic::BasicClient, reqwest::http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret,
|
||||
CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope,
|
||||
TokenResponse, TokenUrl,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use sqlx::SqlitePool;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{AppError, UserData};
|
||||
|
||||
fn get_client(hostname: String) -> Result<BasicClient, AppError> {
|
||||
let google_client_id = ClientId::new(var("GOOGLE_CLIENT_ID")?);
|
||||
let google_client_secret = ClientSecret::new(var("GOOGLE_CLIENT_SECRET")?);
|
||||
let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
|
||||
.map_err(|_| "OAuth: invalid authorization endpoint URL")?;
|
||||
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
|
||||
.map_err(|_| "OAuth: invalid token endpoint URL")?;
|
||||
|
||||
let protocol = if hostname.starts_with("localhost") || hostname.starts_with("127.0.0.1") {
|
||||
"http"
|
||||
} else {
|
||||
"https"
|
||||
};
|
||||
|
||||
let redirect_url = format!("{}://{}/google_auth_return", protocol, hostname);
|
||||
|
||||
// Set up the config for the Google OAuth2 process.
|
||||
let client = BasicClient::new(
|
||||
google_client_id,
|
||||
Some(google_client_secret),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).map_err(|_| "OAuth: invalid redirect URL")?)
|
||||
.set_revocation_uri(
|
||||
RevocationUrl::new("https://oauth2.googleapis.com/revoke".to_string())
|
||||
.map_err(|_| "OAuth: invalid revocation endpoint URL")?,
|
||||
);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
Host(hostname): Host,
|
||||
) -> Result<Redirect, AppError> {
|
||||
|
||||
if user_data.is_some() {
|
||||
// check if already authenticated
|
||||
return Ok(Redirect::to("/"));
|
||||
}
|
||||
|
||||
let return_url = params
|
||||
.remove("return_url")
|
||||
.unwrap_or_else(|| "/".to_string());
|
||||
// TODO: check if return_url is valid
|
||||
|
||||
let client = get_client(hostname)?;
|
||||
|
||||
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let (authorize_url, csrf_state) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new(
|
||||
"https://www.googleapis.com/auth/userinfo.email".to_string(),
|
||||
))
|
||||
.add_scope(Scope::new(
|
||||
"https://www.googleapis.com/auth/userinfo.profile".to_string(),
|
||||
))
|
||||
.set_pkce_challenge(pkce_code_challenge)
|
||||
.url();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO oauth2_state_storage (csrf_state, pkce_code_verifier, return_url) VALUES (?, ?, ?);",
|
||||
)
|
||||
.bind(csrf_state.secret())
|
||||
.bind(pkce_code_verifier.secret())
|
||||
.bind(return_url)
|
||||
.execute(&db_pool)
|
||||
.await?;
|
||||
|
||||
Ok(Redirect::to(authorize_url.as_str()))
|
||||
}
|
||||
|
||||
pub async fn google_auth_return(
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
Host(hostname): Host,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
let state = CsrfToken::new(params.remove("state").ok_or("OAuth: without state")?);
|
||||
let code = AuthorizationCode::new(params.remove("code").ok_or("OAuth: without code")?);
|
||||
|
||||
let query: (String, String) = sqlx::query_as(
|
||||
r#"DELETE FROM oauth2_state_storage WHERE csrf_state = ? RETURNING pkce_code_verifier,return_url"#,
|
||||
)
|
||||
.bind(state.secret())
|
||||
.fetch_one(&db_pool)
|
||||
.await?;
|
||||
|
||||
// Alternative:
|
||||
// let query: (String, String) = sqlx::query_as(
|
||||
// r#"SELECT pkce_code_verifier,return_url FROM oauth2_state_storage WHERE csrf_state = ?"#,
|
||||
// )
|
||||
// .bind(state.secret())
|
||||
// .fetch_one(&db_pool)
|
||||
// .await?;
|
||||
// let _ = sqlx::query("DELETE FROM oauth2_state_storage WHERE csrf_state = ?")
|
||||
// .bind(state.secret())
|
||||
// .execute(&db_pool)
|
||||
// .await;
|
||||
|
||||
let pkce_code_verifier = query.0;
|
||||
let return_url = query.1;
|
||||
let pkce_code_verifier = PkceCodeVerifier::new(pkce_code_verifier);
|
||||
|
||||
// Exchange the code with a token.
|
||||
let client = get_client(hostname)?;
|
||||
let token_response = tokio::task::spawn_blocking(move || {
|
||||
client
|
||||
.exchange_code(code)
|
||||
.set_pkce_verifier(pkce_code_verifier)
|
||||
.request(http_client)
|
||||
})
|
||||
.await
|
||||
.map_err(|_| "OAuth: exchange_code failure")?
|
||||
.map_err(|_| "OAuth: tokio spawn blocking failure")?;
|
||||
let access_token = token_response.access_token().secret();
|
||||
|
||||
// Get user info from Google
|
||||
let url =
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo?oauth_token=".to_owned() + access_token;
|
||||
let body = reqwest::get(url)
|
||||
.await
|
||||
.map_err(|_| "OAuth: reqwest failed to query userinfo")?
|
||||
.text()
|
||||
.await
|
||||
.map_err(|_| "OAuth: reqwest received invalid userinfo")?;
|
||||
let mut body: serde_json::Value =
|
||||
serde_json::from_str(body.as_str()).map_err(|_| "OAuth: Serde failed to parse userinfo")?;
|
||||
let email = body["email"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let name = body["name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let family_name = body["family_name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let given_name = body["given_name"]
|
||||
.take()
|
||||
.as_str()
|
||||
.ok_or("OAuth: Serde failed to parse email address")?
|
||||
.to_owned();
|
||||
let verified_email = body["verified_email"]
|
||||
.take()
|
||||
.as_bool()
|
||||
.ok_or("OAuth: Serde failed to parse verified_email")?;
|
||||
if !verified_email {
|
||||
return Err(AppError::new("OAuth: email address is not verified".to_owned())
|
||||
.with_user_message("Your email address is not verified. Please verify your email address with Google and try again.".to_owned()));
|
||||
}
|
||||
|
||||
// Check if user exists in database
|
||||
// If not, create a new user
|
||||
let query: Result<(i64,), _> = sqlx::query_as(r#"SELECT id FROM users WHERE email=?"#)
|
||||
.bind(email.as_str())
|
||||
.fetch_one(&db_pool)
|
||||
.await;
|
||||
let user_id = if let Ok(query) = query {
|
||||
query.0
|
||||
} else {
|
||||
let query: (i64,) = sqlx::query_as("INSERT INTO users (email, name, family_name, given_name) VALUES (?, ?, ?, ?) RETURNING id")
|
||||
.bind(email)
|
||||
.bind(name)
|
||||
.bind(family_name)
|
||||
.bind(given_name)
|
||||
.fetch_one(&db_pool)
|
||||
.await?;
|
||||
query.0
|
||||
};
|
||||
|
||||
// Create a session for the user
|
||||
let session_token_p1 = Uuid::new_v4().to_string();
|
||||
let session_token_p2 = Uuid::new_v4().to_string();
|
||||
let session_token = [session_token_p1.as_str(), "_", session_token_p2.as_str()].concat();
|
||||
let headers = axum::response::AppendHeaders([(
|
||||
axum::http::header::SET_COOKIE,
|
||||
"session_token=".to_owned()
|
||||
+ &*session_token
|
||||
+ "; path=/; httponly; secure; samesite=strict",
|
||||
)]);
|
||||
let now = Utc::now().timestamp();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO user_sessions
|
||||
(session_token_p1, session_token_p2, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?);",
|
||||
)
|
||||
.bind(session_token_p1)
|
||||
.bind(session_token_p2)
|
||||
.bind(user_id)
|
||||
.bind(now)
|
||||
.bind(now + 60 * 60 * 24)
|
||||
.execute(&db_pool)
|
||||
.await?;
|
||||
|
||||
Ok((headers, Redirect::to(return_url.as_str())))
|
||||
}
|
||||
|
||||
pub async fn logout(
|
||||
cookie: Option<TypedHeader<Cookie>>,
|
||||
State(db_pool): State<SqlitePool>,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
if let Some(cookie) = cookie {
|
||||
if let Some(session_token) = cookie.get("session_token") {
|
||||
let session_token: Vec<&str> = session_token.split('_').collect();
|
||||
let _ = sqlx::query("DELETE FROM user_sessions WHERE session_token_1 = ?")
|
||||
.bind(session_token[0])
|
||||
.execute(&db_pool)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
let headers = axum::response::AppendHeaders([(
|
||||
axum::http::header::SET_COOKIE,
|
||||
"session_token=deleted; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT",
|
||||
)]);
|
||||
Ok((headers, Redirect::to("/")))
|
||||
}
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
use askama_axum::Template;
|
||||
use axum::{extract::{Path, State}, response::IntoResponse, Extension};
|
||||
use http::Request;
|
||||
use askama_axum::{Response, Template};
|
||||
use axum::{extract::{Path, State}, response::{Html, IntoResponse}, Extension};
|
||||
use http::{Request, StatusCode};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::{HtmlTemplate, UserData};
|
||||
use crate::UserData;
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "profile.html")]
|
||||
|
|
@ -13,6 +13,64 @@ struct ProfileTemplate {
|
|||
user: UserData
|
||||
}
|
||||
|
||||
struct HtmlTemplate<T>(T);
|
||||
|
||||
impl<T> IntoResponse for HtmlTemplate<T>
|
||||
where
|
||||
T: Template,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
match self.0.render() {
|
||||
Ok(html) => Html(html).into_response(),
|
||||
Err(err) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to render template. Error: {}", err),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "index.html")]
|
||||
struct IndexTemplate {
|
||||
logged_in: bool,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "dashboard.html")]
|
||||
struct DashboardTemplate {
|
||||
logged_in: bool,
|
||||
name: String,
|
||||
}
|
||||
|
||||
pub async fn index<T>(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
_request: Request<T>,
|
||||
) -> impl IntoResponse {
|
||||
let user_name = user_data.map(|s| s.name);
|
||||
let logged_in = user_name.is_some();
|
||||
let name = user_name.unwrap_or_default();
|
||||
|
||||
println!("logged_in: {}, name: {}", logged_in, name);
|
||||
|
||||
let template = IndexTemplate { logged_in, name };
|
||||
HtmlTemplate(template)
|
||||
}
|
||||
|
||||
pub async fn dashboard<T>(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
_request: Request<T>,
|
||||
) -> impl IntoResponse {
|
||||
let user_name = user_data.map(|s| s.name);
|
||||
let logged_in = user_name.is_some();
|
||||
let name = user_name.unwrap_or_default();
|
||||
|
||||
let template = DashboardTemplate { logged_in, name};
|
||||
HtmlTemplate(template)
|
||||
}
|
||||
|
||||
/// Handles the profile page.
|
||||
pub async fn profile<T>(
|
||||
Extension(user_data): Extension<Option<UserData>>,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,39 @@
|
|||
<div class="container py-5 h-100">
|
||||
<br>
|
||||
<p>This will be the private information area for the extended Jean-Marie family.</p>
|
||||
<div>
|
||||
<h2>Web links</h2>
|
||||
<h3>Fonts</h3>
|
||||
<ul>
|
||||
<li><a href="https://fonts.google.com">Google fonts</a></li>
|
||||
<li><a href="https://www.fontspace.com">Font Space</a></li>
|
||||
</ul>
|
||||
<h3>Family tree</h3>
|
||||
<ul>
|
||||
<li><a href="https://www.ancestry.com">Ancestry</a></li>
|
||||
<li><a href="https://www.geni.com">Geni</a></li>
|
||||
</ul>
|
||||
{% extends "base.html" %}
|
||||
{% block content %}
|
||||
<div class="container-fluid">
|
||||
<div class="row align-items-stretch">
|
||||
<div id="menu" class="col-md-2 bg-light">
|
||||
<!-- internal menu -->
|
||||
<h2>Menu</h2>
|
||||
<ul>
|
||||
<li>Web links</li>
|
||||
<li><a href="/useradmin">User Administration</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="col-8">
|
||||
<p>This will be the private information area for the extended Jean-Marie family.</p>
|
||||
<div>
|
||||
<h2>Web links</h2>
|
||||
<h3>TLC Creations</h3>
|
||||
<ul>
|
||||
<li><a href="https://www.tlccreations.ca">TLC Creations</a></li>
|
||||
</ul>
|
||||
<h3>Fonts</h3>
|
||||
<ul>
|
||||
<li><a href="https://fonts.google.com">Google fonts</a></li>
|
||||
<li><a href="https://www.fontspace.com">Font Space</a></li>
|
||||
</ul>
|
||||
<h3>Family tree</h3>
|
||||
<ul>
|
||||
<li><a href="https://www.ancestry.com">Ancestry</a></li>
|
||||
<li><a href="https://www.geni.com">Geni</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div id="events" class="col-2 bg-light">
|
||||
<!-- events -->
|
||||
<h2>Events</h2>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endblock content %}
|
||||
|
|
@ -1,42 +1,10 @@
|
|||
{% extends "base.html" %}
|
||||
{% block content %}
|
||||
{% if logged_in %}
|
||||
<div class="container-fluid">
|
||||
<div class="row align-items-stretch">
|
||||
<div id="menu" class="col-md-2 bg-light">
|
||||
<!-- internal menu -->
|
||||
<h2>Menu</h2>
|
||||
<ul>
|
||||
<li>Web links</li>
|
||||
<li><a href="/useradmin">User Administration</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="col-8">
|
||||
<p>This will be the private information area for the extended Jean-Marie family.</p>
|
||||
<div>
|
||||
<h2>Web links</h2>
|
||||
<h3>TLC Creations</h3>
|
||||
<ul>
|
||||
<li><a href="https://www.tlccreations.ca">TLC Creations</a></li>
|
||||
</ul>
|
||||
<h3>Fonts</h3>
|
||||
<ul>
|
||||
<li><a href="https://fonts.google.com">Google fonts</a></li>
|
||||
<li><a href="https://www.fontspace.com">Font Space</a></li>
|
||||
</ul>
|
||||
<h3>Family tree</h3>
|
||||
<ul>
|
||||
<li><a href="https://www.ancestry.com">Ancestry</a></li>
|
||||
<li><a href="https://www.geni.com">Geni</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div id="events" class="col-2 bg-light">
|
||||
<!-- events -->
|
||||
<h2>Events</h2>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Redirect to dashboard -->
|
||||
<script>
|
||||
window.location.href = "/dashboard";
|
||||
</script>
|
||||
{% else %}
|
||||
<!-- Carousel -->
|
||||
<div id="demo" class="carousel slide" data-bs-ride="carousel">
|
||||
|
|
@ -81,5 +49,6 @@
|
|||
<span class="carousel-control-next-icon"></span>
|
||||
</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endblock content %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endblock content %}
|
||||
Loading…
Reference in New Issue