mod comment; use actix_cors::Cors; pub use comment::*; mod database; pub use database::Database; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use clap::Parser; use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str}; use scraper::{Html, Selector}; use serde::Deserialize; use std::sync::Mutex; use std::{collections::HashMap, sync::MutexGuard}; use validator::Validate; struct AppState { databases: HashMap>, arguments: Arguments, } enum DatabaseAccessError { BadOrigin, AccessError, DatabaseError, } impl DatabaseAccessError { fn to_http_response(&self) -> HttpResponse { match self { Self::BadOrigin => HttpResponse::BadRequest().reason("bad origin").finish(), Self::AccessError => HttpResponse::InternalServerError() .reason("database access error") .finish(), // e.g. PoisonError Self::DatabaseError => HttpResponse::InternalServerError() .reason("database error") .finish(), } } } impl AppState { fn get_db<'a>( &'a self, request: &HttpRequest, ) -> Result, DatabaseAccessError> { self.get_db_with_origin(get_request_origin(request)) } fn get_db_with_origin<'a>( &'a self, origin: Option, ) -> Result, DatabaseAccessError> { let origin = match origin { Some(origin) => origin, None => return Err(DatabaseAccessError::BadOrigin), }; match self.databases.get(&origin) { Some(database) => Ok(match database.lock() { Ok(database) => database, Err(_) => return Err(DatabaseAccessError::AccessError), }), None => return Err(DatabaseAccessError::AccessError), } } } fn get_request_origin(request: &HttpRequest) -> Option { match request.head().headers().get("Origin") { Some(origin) => match origin.to_str() { Ok(origin) => Some(origin.to_owned()), Err(_) => None, }, None => None, } } #[derive(Default, Parser)] #[clap(author, version, about)] struct Arguments { #[clap( short, long, default_value = "8080", help = "Set port where HTTP requests will be received" )] port: u16, #[clap( required = true, min_values = 1, help = "Set sites where comments will be posted" )] sites: Vec, #[clap( short, long, help = "Run in testing mode, with in-memory database(s) and permissive CORS policy" )] testing: bool, #[clap(short, long, help = "Require name for comment submissions")] name_required: bool, #[clap(short, long, help = "Require email for comment submissions")] email_required: bool, } #[get("/{content_id}")] async fn get_comments( data: web::Data, request: HttpRequest, content_id: web::Path, ) -> impl Responder { let origin = get_request_origin(&request); let comments = match web::block(move || { Ok( match match data.get_db_with_origin(origin) { Ok(database) => database, Err(err) => return Err(err), } .get_comments(&content_id) { Ok(comments) => comments, Err(_) => return Err(DatabaseAccessError::DatabaseError), }, ) }) .await { Ok(comments) => match comments { Ok(comments) => comments, Err(err) => return err.to_http_response(), }, Err(_) => return DatabaseAccessError::AccessError.to_http_response(), }; HttpResponse::Ok().json(comments) } #[derive(Deserialize)] struct PostCommentsRequest { url: String, comment: Comment, } #[post("/")] async fn post_comment( data: web::Data, request: HttpRequest, bytes: web::Bytes, ) -> impl Responder { match String::from_utf8(bytes.to_vec()) { Ok(text) => { let PostCommentsRequest { url, comment } = match serde_json::from_str::(&text) { Ok(mut req) => { let mut sanitize_req = || -> Result<(), SanitizeError> { req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)?.replace(">", ">"); // required for markdown quotes if let Some(ref mut author) = req.comment.author { *author = sanitize_str(&DEFAULT, &author)?; } Ok(()) }; if let Err(_) = sanitize_req() { return HttpResponse::InternalServerError() .reason("failed to sanitize request") .finish(); } req } Err(_) => { return HttpResponse::BadRequest() .reason("invalid request body") .finish() } }; if comment.validate().is_err() { return HttpResponse::BadRequest() .reason("invalid comment field(s)") .finish(); } if comment.author.is_none() && data.arguments.name_required { return HttpResponse::BadRequest().reason("name required").finish(); } if comment.email.is_none() && data.arguments.email_required { return HttpResponse::BadRequest().reason("email required").finish(); } let origin = match request.head().headers().get("Origin") { Some(origin) => match origin.to_str() { Ok(origin) => origin, // If the Origin is not valid ASCII, it is a bad request not sent from a browser Err(_) => return HttpResponse::BadRequest().reason("bad origin").finish(), }, // If there is no Origin header, it is a bad request not sent from a browser None => return HttpResponse::BadRequest().reason("bad origin").finish(), }; // Check to see if provided URL is in scope. // This is to prevent malicious requests that try to get server to fetch external websites. // (requires loop because "labels on blocks are unstable") // https://github.com/rust-lang/rust/issues/48594 'outer: loop { for site_root in data.databases.keys() { if site_root.starts_with(origin) && url.starts_with(site_root) { break 'outer; } } return HttpResponse::BadRequest() .reason("url out of scope") .finish(); } match get_page_data(&url).await { Ok(page_data_option) => match page_data_option { Some(page_data) => { if page_data.content_id != comment.content_id { return HttpResponse::BadRequest() .reason("content ids don't match") .finish(); } } None => return HttpResponse::BadRequest().reason("url invalid").finish(), // e.g. 404 }, Err(_) => { return HttpResponse::InternalServerError() .reason("failed to get page data") .finish() } }; let database = match data.get_db(&request) { Ok(database) => database, Err(err) => return err.to_http_response(), }; if let Some(parent) = comment.parent { 'outer2: loop { match database.get_comments(&comment.content_id) { Ok(comments) => { for other_comment in comments.iter() { if other_comment.id.unwrap() == parent { if other_comment.parent.is_none() { break 'outer2; } break; } } } Err(_) => { return HttpResponse::InternalServerError() .reason("failed to get comments") .finish() } } return HttpResponse::BadRequest() .reason("invalid comment parent") .finish(); } } if let Err(_) = database.create_comment(&comment) { return HttpResponse::InternalServerError() .reason("failed to create comment") .finish(); } HttpResponse::Ok().into() } Err(_) => HttpResponse::BadRequest() .reason("failed to parse request body") .finish(), } } // Contains all page details stored in meta tags. // Currently, only content_id, but this is wrapped in this struct // to make adding other meta tags, such as locked comments, in the future struct PageData { content_id: String, } async fn get_page_data(url: &str) -> Result, reqwest::Error> { let response = reqwest::get(url).await?; if !response.status().is_success() { return Ok(None); } let content = response.text_with_charset("utf-8").await?; let document = Html::parse_document(&content); let get_meta = |name: &str| -> Option { let selector = Selector::parse(&format!("meta[name=\"{}\"]", name)).unwrap(); match document.select(&selector).next() { Some(element) => match element.value().attr("content") { Some(value) => Some(value.to_owned()), None => return None, }, None => return None, } }; return Ok(Some(PageData { content_id: match get_meta("soudan-content-id") { Some(id) => id, None => return Ok(None), }, })); } #[actix_web::main] async fn main() -> Result<(), std::io::Error> { let arguments = Arguments::parse(); let mut databases = HashMap::new(); for domain in arguments.sites.iter() { databases.insert( domain.to_owned(), Mutex::new(Database::new(arguments.testing, domain).unwrap()), ); } let port = arguments.port; let state = web::Data::new(AppState { databases, arguments, }); HttpServer::new(move || { App::new() .service(get_comments) .service(post_comment) .app_data(state.clone()) // Issue with CORS on POST requests, // keeping permissive for now .wrap( Cors::permissive(), /* if arguments.testing { Cors::permissive() } else { let mut cors = Cors::default() .allowed_methods(vec!["GET", "POST"]); for domain in arguments.sites.iter() { cors = cors.allowed_origin(domain); } cors } */ ) }) .bind(("127.0.0.1", port))? .run() .await }