diff --git a/src/database.rs b/src/database.rs index dcbdbd0..0d8141e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -7,9 +7,7 @@ pub struct Database { impl Database { pub fn new(testing: bool, name: &str) -> Result { - let name = name - .replace("http://", "") - .replace("https://", ""); + let name = name.replace("http://", "").replace("https://", ""); let conn = if testing { Connection::open_in_memory() } else { diff --git a/src/main.rs b/src/main.rs index 8ddc3a9..9e17c94 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,13 +6,13 @@ mod database; pub use database::Database; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use clap::Parser; +use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str}; use scraper::{Html, Selector}; use serde::Deserialize; -use std::{collections::HashMap, sync::MutexGuard}; use std::sync::Mutex; +use std::{collections::HashMap, sync::MutexGuard}; use validator::Validate; -use sanitize_html::{sanitize_str, rules::predefined::DEFAULT, errors::SanitizeError}; -use clap::Parser; struct AppState { databases: HashMap>, @@ -22,15 +22,19 @@ struct AppState { enum DatabaseAccessError { BadOrigin, AccessError, - DatabaseError + DatabaseError, } impl DatabaseAccessError { fn to_http_response(&self) -> HttpResponse { match self { Self::BadOrigin => HttpResponse::BadRequest().reason("bad origin").finish(), - Self::AccessError => HttpResponse::InternalServerError().reason("database access error").finish(), // e.g. PoisonError - Self::DatabaseError => HttpResponse::InternalServerError().reason("database error").finish(), + Self::AccessError => HttpResponse::InternalServerError() + .reason("database access error") + .finish(), // e.g. PoisonError + Self::DatabaseError => HttpResponse::InternalServerError() + .reason("database error") + .finish(), } } } @@ -74,11 +78,24 @@ fn get_request_origin(request: &HttpRequest) -> Option { #[derive(Default, Parser)] #[clap(author, version, about)] struct Arguments { - #[clap(short, long, default_value = "8080", help = "Set port where HTTP requests will be received")] + #[clap( + short, + long, + default_value = "8080", + help = "Set port where HTTP requests will be received" + )] port: u16, - #[clap(required = true, min_values = 1, help = "Set sites where comments will be posted")] + #[clap( + required = true, + min_values = 1, + help = "Set sites where comments will be posted" + )] sites: Vec, - #[clap(short, long, help = "Run in testing mode, with in-memory database(s) and permissive CORS policy")] + #[clap( + short, + long, + help = "Run in testing mode, with in-memory database(s) and permissive CORS policy" + )] testing: bool, #[clap(short, long, help = "Require name for comment submissions")] name_required: bool, @@ -94,14 +111,20 @@ async fn get_comments( ) -> impl Responder { let origin = get_request_origin(&request); let comments = match web::block(move || { - Ok(match match data.get_db_with_origin(origin) { - Ok(database) => database, - Err(err) => return Err(err), - }.get_comments(&content_id) { - Ok(comments) => comments, - Err(_) => return Err(DatabaseAccessError::DatabaseError), - }) - }).await { + Ok( + match match data.get_db_with_origin(origin) { + Ok(database) => database, + Err(err) => return Err(err), + } + .get_comments(&content_id) + { + Ok(comments) => comments, + Err(_) => return Err(DatabaseAccessError::DatabaseError), + }, + ) + }) + .await + { Ok(comments) => match comments { Ok(comments) => comments, Err(err) => return err.to_http_response(), @@ -125,25 +148,34 @@ async fn post_comment( ) -> impl Responder { match String::from_utf8(bytes.to_vec()) { Ok(text) => { - let PostCommentsRequest { url, comment } = match serde_json::from_str::(&text) { - Ok(mut req) => { - let mut sanitize_req = || -> Result<(), SanitizeError> { - req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)? - .replace(">", ">"); // required for markdown quotes - if let Some(ref mut author) = req.comment.author { - *author = sanitize_str(&DEFAULT, &author)?; + let PostCommentsRequest { url, comment } = + match serde_json::from_str::(&text) { + Ok(mut req) => { + let mut sanitize_req = || -> Result<(), SanitizeError> { + req.comment.text = + sanitize_str(&DEFAULT, &req.comment.text)?.replace(">", ">"); // required for markdown quotes + if let Some(ref mut author) = req.comment.author { + *author = sanitize_str(&DEFAULT, &author)?; + } + Ok(()) + }; + if let Err(_) = sanitize_req() { + return HttpResponse::InternalServerError() + .reason("failed to sanitize request") + .finish(); } - Ok(()) - }; - if let Err(_) = sanitize_req() { - return HttpResponse::InternalServerError().reason("failed to sanitize request").finish(); + req } - req - } - Err(_) => return HttpResponse::BadRequest().reason("invalid request body").finish(), - }; + Err(_) => { + return HttpResponse::BadRequest() + .reason("invalid request body") + .finish() + } + }; if comment.validate().is_err() { - return HttpResponse::BadRequest().reason("invalid comment field(s)").finish(); + return HttpResponse::BadRequest() + .reason("invalid comment field(s)") + .finish(); } if comment.author.is_none() && data.arguments.name_required { return HttpResponse::BadRequest().reason("name required").finish(); @@ -170,18 +202,26 @@ async fn post_comment( break 'outer; } } - return HttpResponse::BadRequest().reason("url out of scope").finish(); + return HttpResponse::BadRequest() + .reason("url out of scope") + .finish(); } match get_page_data(&url).await { Ok(page_data_option) => match page_data_option { Some(page_data) => { if page_data.content_id != comment.content_id { - return HttpResponse::BadRequest().reason("content ids don't match").finish(); + return HttpResponse::BadRequest() + .reason("content ids don't match") + .finish(); } } None => return HttpResponse::BadRequest().reason("url invalid").finish(), // e.g. 404 }, - Err(_) => return HttpResponse::InternalServerError().reason("failed to get page data").finish(), + Err(_) => { + return HttpResponse::InternalServerError() + .reason("failed to get page data") + .finish() + } }; let database = match data.get_db(&request) { Ok(database) => database, @@ -190,25 +230,37 @@ async fn post_comment( if let Some(parent) = comment.parent { 'outer2: loop { match database.get_comments(&comment.content_id) { - Ok(comments) => for other_comment in comments.iter() { - if other_comment.id.unwrap() == parent { - if other_comment.parent.is_none() { - break 'outer2; + Ok(comments) => { + for other_comment in comments.iter() { + if other_comment.id.unwrap() == parent { + if other_comment.parent.is_none() { + break 'outer2; + } + break; } - break; } - }, - Err(_) => return HttpResponse::InternalServerError().reason("failed to get comments").finish(), + } + Err(_) => { + return HttpResponse::InternalServerError() + .reason("failed to get comments") + .finish() + } } - return HttpResponse::BadRequest().reason("invalid comment parent").finish(); + return HttpResponse::BadRequest() + .reason("invalid comment parent") + .finish(); } } if let Err(_) = database.create_comment(&comment) { - return HttpResponse::InternalServerError().reason("failed to create comment").finish(); + return HttpResponse::InternalServerError() + .reason("failed to create comment") + .finish(); } HttpResponse::Ok().into() } - Err(_) => HttpResponse::BadRequest().reason("failed to parse request body").finish(), + Err(_) => HttpResponse::BadRequest() + .reason("failed to parse request body") + .finish(), } } @@ -255,7 +307,10 @@ async fn main() -> Result<(), std::io::Error> { ); } let port = arguments.port; - let state = web::Data::new(AppState { databases, arguments }); + let state = web::Data::new(AppState { + databases, + arguments, + }); HttpServer::new(move || { App::new() .service(get_comments) @@ -263,16 +318,18 @@ async fn main() -> Result<(), std::io::Error> { .app_data(state.clone()) // Issue with CORS on POST requests, // keeping permissive for now - .wrap(Cors::permissive() /* if arguments.testing { - Cors::permissive() - } else { - let mut cors = Cors::default() - .allowed_methods(vec!["GET", "POST"]); - for domain in arguments.sites.iter() { - cors = cors.allowed_origin(domain); - } - cors - } */) + .wrap( + Cors::permissive(), /* if arguments.testing { + Cors::permissive() + } else { + let mut cors = Cors::default() + .allowed_methods(vec!["GET", "POST"]); + for domain in arguments.sites.iter() { + cors = cors.allowed_origin(domain); + } + cors + } */ + ) }) .bind(("127.0.0.1", port))? .run()