diff --git a/Cargo.lock b/Cargo.lock index 03ed8f8..87d9d54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -422,6 +422,12 @@ dependencies = [ "xdg", ] +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "cookie" version = "0.11.5" @@ -522,6 +528,19 @@ dependencies = [ "cipher", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2 1.0.59", + "quote 1.0.28", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "deunicode" version = "0.4.3" @@ -2363,6 +2382,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.37.19" @@ -2461,6 +2489,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" + [[package]] name = "serde" version = "1.0.163" @@ -2654,6 +2688,17 @@ dependencies = [ "unicode-xid 0.1.0", ] +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2 1.0.59", + "quote 1.0.28", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.18" @@ -2695,6 +2740,7 @@ version = "0.1.0" dependencies = [ "chrono", "comrak", + "derive_more", "dotenv", "reqwest", "rocket 0.5.0-rc.3", @@ -2702,6 +2748,7 @@ dependencies = [ "rocket_dyn_templates", "sass-rocket-fairing", "serde", + "serde_json", "serde_yaml", ] diff --git a/Cargo.toml b/Cargo.toml index 2f4e93c..f3ee4f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] chrono = { version = "0.4.26", features = ["serde"] } comrak = "0.18.0" +derive_more = "0.99.17" dotenv = "0.15.0" reqwest = "0.11.18" rocket = { version = "=0.5.0-rc.3", features = ["secrets"] } @@ -15,4 +16,5 @@ rocket_contrib = { version = "0.4.11", features = ["templates"] } rocket_dyn_templates = { version = "0.1.0-rc.3", features = ["tera"] } sass-rocket-fairing = "0.2.0" serde = "1.0.163" +serde_json = "1.0.96" serde_yaml = "0.9.21" diff --git a/src/main.rs b/src/main.rs index 2a4a149..c5bc049 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,13 +2,19 @@ extern crate rocket; use core::panic; +use chrono::{Duration, Utc}; +use derive_more::From; +use reqwest::StatusCode; +use rocket::form::Form; use rocket::fs::{relative, FileServer}; use rocket::http::{Cookie, CookieJar}; use rocket::request::{FromRequest}; -use rocket::response::{Redirect}; +use rocket::response::{Redirect, content::RawHtml}; use rocket::{request, Request}; use rocket_dyn_templates::{context, Template}; use sass_rocket_fairing::SassFairing; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize, Serializer}; use std::collections::HashMap; use std::convert::Infallible; use std::env; @@ -20,16 +26,12 @@ use challenge::Challenge; mod kyujitai; #[get("/")] -fn get_challenge(challenge: u32, cookies: &CookieJar<'_>) -> Template { - let value = cookies - .get_private(TOKEN_COOKIE) - .map(|cookie| cookie.value().to_owned()); - let logged_in = value.is_some(); +async fn get_challenge(challenge: u32, cookies: &CookieJar<'_>) -> Template { Template::render( "index", context! { challenge, - logged_in, + user: User::get(cookies).await.unwrap(), content: { use comrak::{parse_document, Arena, ComrakOptions}; let options = { @@ -66,19 +68,41 @@ fn get_challenge(challenge: u32, cookies: &CookieJar<'_>) -> Template { #[get("/login")] fn login() -> Redirect { Redirect::to(format!( - "https://discord.com/api/oauth2/authorize?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&scope=identify%20guilds.join%20guilds", + // Switch from response_type=code to response_type=token from URL generator + "https://discord.com/api/oauth2/authorize?client_id={client_id}&redirect_uri={redirect_uri}&response_type=token&scope=identify%20guilds.join%20guilds", client_id = env::var("CLIENT_ID").unwrap(), - redirect_uri = format!("{}login", env::var("DOMAIN").unwrap()), + redirect_uri = format!("{}success", env::var("DOMAIN").unwrap()), )) - // TODO: After returning from Discord go to previous page (with Referer?) } -#[get("/login?")] -fn login_success(code: String, cookies: &CookieJar<'_>) -> Redirect { - cookies.add_private(Cookie::new(TOKEN_COOKIE, code)); + +#[derive(FromForm)] +struct Login<'r> { + access_token: &'r str, + expires_in: u64, +} + +#[post("/login", data = "")] +fn post_login(login: Form>, cookies: &CookieJar<'_>) -> Redirect { + cookies.add_private(Cookie::new(TOKEN_COOKIE, login.access_token.to_owned())); + cookies.add(Cookie::new(TOKEN_EXPIRE_COOKIE, (Utc::now() + Duration::seconds(login.expires_in as i64)).timestamp().to_string())); Redirect::to("/") } +#[get("/success")] +fn success() -> RawHtml<&'static str> { + RawHtml("
+ + +
+") +} + struct Referer(Option); #[rocket::async_trait] @@ -92,6 +116,7 @@ impl<'r> FromRequest<'r> for Referer { } const TOKEN_COOKIE: &str = "token"; +const TOKEN_EXPIRE_COOKIE: &str = "token_expire"; #[get("/logout")] fn logout(cookies: &CookieJar<'_>, referer: Referer) -> Redirect { @@ -124,13 +149,128 @@ fn logout(cookies: &CookieJar<'_>, referer: Referer) -> Redirect { Redirect::to(redirect_url) } +#[derive(Default, Deserialize)] +struct User { + #[serde(deserialize_with = "deserialize_id")] + id: u64, + #[serde(rename = "username")] + name: String, + #[serde(deserialize_with = "deserialize_discriminator")] + discriminator: u16, + avatar: String, +} + +fn deserialize_id<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let id_str: &str = serde::Deserialize::deserialize(deserializer)?; + id_str.parse().map_err(serde::de::Error::custom) +} + +fn deserialize_discriminator<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let id_str: &str = serde::Deserialize::deserialize(deserializer)?; + id_str.parse().map_err(serde::de::Error::custom) +} + +impl Serialize for User { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("User", 5)?; + state.serialize_field("id", &self.id)?; + state.serialize_field("name", &self.name)?; + state.serialize_field("discriminator", &self.discriminator)?; + state.serialize_field("avatar", &self.avatar)?; + state.serialize_field("username", &self.username())?; + state.end() + } +} + +#[derive(From, Debug)] +enum GetUserError { + ReqwestError(reqwest::Error), + DeserializeError(serde_json::Error), + #[allow(unused)] + DiscordError { status: StatusCode, message: Option }, +} + +impl User { + fn username(&self) -> String { + if self.discriminator == 0 { + return self.name.clone(); + } + format!("{}#{:0>4}", self.name, self.discriminator) + } + + async fn get(cookies: &CookieJar<'_>) -> Result, GetUserError> { + let token = match cookies.get_private(TOKEN_COOKIE) { + Some(cookie) => cookie.value().to_owned(), + None => return Ok(None), + }; + if cookies.get(TOKEN_EXPIRE_COOKIE) + .map(|expire| expire.value().parse::()) + .map(Result::ok) + .flatten() + .map_or(true, |timestamp| Utc::now().timestamp() >= timestamp) { + cookies.remove_private(Cookie::named(TOKEN_COOKIE)); + cookies.remove(Cookie::named(TOKEN_EXPIRE_COOKIE)); + return Ok(None); + } + let (status, text) = { + let response = reqwest::Client::new() + .get("https://discord.com/api/users/@me") + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + (response.status(), response.text().await) + }; + if !status.is_success() { + return Err(GetUserError::DiscordError { + status, + message: text.ok(), + }) + } + Ok(Some(serde_json::from_str(&text?)?)) + } +} + #[launch] fn rocket() -> _ { let config = rocket::Config::figment().merge(("port", 1313)); dotenv::dotenv().expect("Failed to load .env file"); rocket::custom(config) - .mount("/", routes![get_challenge, login, login_success, logout]) + .mount("/", routes![get_challenge, login, post_login, success, logout]) .mount("/css", FileServer::from(relative!("styles/css"))) .attach(Template::fairing()) .attach(SassFairing::default()) } + + +#[cfg(test)] +mod tests { + use crate::User; + + fn test_user(name: &str, discriminator: u16) -> User { + let mut user = User::default(); + user.name = name.to_owned(); + user.discriminator = discriminator; + user + } + + #[test] + fn test_legacy_username() { + let user = test_user("test", 123); + assert_eq!(user.username(), "test#0123"); + } + + #[test] + fn test_new_username() { + let user = test_user("test", 0); + assert_eq!(user.username(), "test"); + } +} \ No newline at end of file diff --git a/templates/index.html.tera b/templates/index.html.tera index 297d133..e11989c 100644 --- a/templates/index.html.tera +++ b/templates/index.html.tera @@ -18,9 +18,9 @@ {{ content.song.japanese }} {% endif %} - {% if logged_in %} + {% if user %} - @mochamoko + {{ user.username }} Log out