diff --git a/src/error.rs b/src/error.rs index 415f600..e38eff5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,50 +1,72 @@ -use actix_web::HttpResponse; +use crate::Database; +use actix_web::{HttpResponse, error::BlockingError}; +use std::{sync::{PoisonError, MutexGuard}, string::FromUtf8Error}; +use sanitize_html::errors::SanitizeError; +use validator::ValidationErrors; +use derive_more::From; + +#[derive(From, Debug)] pub enum Error { InvalidOrigin, - InvalidBody, + InvalidBodyEncoding(FromUtf8Error), + InvalidBodyJson(serde_json::Error), InvalidUrl, - InvalidFields, + InvalidFields(ValidationErrors), InvalidContentId, InvalidParent, EmailRequired, NameRequired, - DatabaseAccessError, - DatabaseInternalError, - SanitizationError, - PageFetchError, + DatabaseAccessBlockingError(BlockingError), + DatabaseAccessPoisonError, + DatabaseInternalError(rusqlite::Error), + SanitizationError(SanitizeError), + PageFetchError(reqwest::Error), } impl Error { pub fn to_http_response(&self) -> HttpResponse { match self { Self::InvalidOrigin - | Self::InvalidBody + | Self::InvalidBodyEncoding(_) + | Self::InvalidBodyJson(_) | Self::InvalidUrl - | Self::InvalidFields + | Self::InvalidFields(_) | Self::InvalidContentId | Self::InvalidParent | Self::EmailRequired | Self::NameRequired => HttpResponse::BadRequest(), - Self::DatabaseAccessError - | Self::DatabaseInternalError - | Self::SanitizationError - | Self::PageFetchError => HttpResponse::InternalServerError(), + Self::DatabaseAccessBlockingError(_) + | Self::DatabaseAccessPoisonError + | Self::DatabaseInternalError(_) + | Self::SanitizationError(_) + | Self::PageFetchError(_) => { + eprintln!("{:?}", self); + HttpResponse::InternalServerError() + }, } .reason(match self { Self::InvalidOrigin => "invalid request origin", - Self::InvalidBody => "invalid request body", + Self::InvalidBodyEncoding(_) + | Self::InvalidBodyJson(_) => "invalid request body", Self::InvalidUrl => "invalid request url", - Self::InvalidFields => "invalid request field", + Self::InvalidFields(_) => "invalid request field", Self::InvalidContentId => "invalid request content id", Self::InvalidParent => "invalid comment parent", Self::EmailRequired => "comment email required", Self::NameRequired => "comment name required", - Self::DatabaseAccessError => "database access error", - Self::DatabaseInternalError => "database internal error", - Self::SanitizationError => "comment sanitization error", - Self::PageFetchError => "page fetch error", + Self::DatabaseAccessBlockingError(_) + | Self::DatabaseAccessPoisonError => "database access error", + Self::DatabaseInternalError(_) => "database internal error", + Self::SanitizationError(_) => "comment sanitization error", + Self::PageFetchError(_) => "page fetch error", }) .finish() } } + +impl<'a> From>> for Error { + fn from(_: PoisonError>) -> Self { + Self::DatabaseAccessPoisonError + } +} diff --git a/src/main.rs b/src/main.rs index 95f8303..73c1941 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ pub use error::*; use actix_cors::Cors; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer}; use clap::Parser; -use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str}; +use sanitize_html::{rules::predefined::DEFAULT, sanitize_str}; use scraper::{Html, Selector}; use serde::Deserialize; use std::fs::File; @@ -23,16 +23,9 @@ struct AppState { } impl AppState { - fn get_db<'a>(&'a self, origin: Option) -> Result, Error> { - let origin = match origin { - Some(origin) => origin, - None => return Err(Error::InvalidOrigin), - }; - match self.databases.get(&origin) { - Some(database) => Ok(match database.lock() { - Ok(database) => database, - Err(_) => return Err(Error::DatabaseAccessError), - }), + fn get_db<'a>(&'a self, origin: &str) -> Result, Error> { + match self.databases.get(origin) { + Some(database) => Ok(database.lock()?), None => return Err(Error::InvalidOrigin), } } @@ -42,13 +35,13 @@ fn trim_protocol(url: &str) -> String { url.replace("http://", "").replace("https://", "") } -fn get_request_origin(request: &HttpRequest) -> Option { +fn get_request_origin(request: &HttpRequest) -> Result { match request.head().headers().get("Origin") { Some(origin) => match origin.to_str() { - Ok(origin) => Some(trim_protocol(origin)), - Err(_) => None, + Ok(origin) => Ok(trim_protocol(origin)), + Err(_) => Err(Error::InvalidOrigin), }, - None => None, + None => Err(Error::InvalidOrigin), } } @@ -77,25 +70,8 @@ async fn _get_comments( request: HttpRequest, content_id: web::Path, ) -> Result, Error> { - let origin = get_request_origin(&request); - match web::block(move || { - Ok( - match match data.get_db(origin) { - Ok(database) => database, - Err(err) => return Err(err), - } - .get_comments(&content_id) - { - Ok(comments) => comments, - Err(_) => return Err(Error::DatabaseInternalError), - }, - ) - }) - .await - { - Ok(result) => result, - Err(_) => Err(Error::DatabaseAccessError), - } + let origin = get_request_origin(&request)?; + web::block(move || Ok(data.get_db(&origin)?.get_comments(&content_id)?)).await? } #[get("/{content_id}")] @@ -121,104 +97,61 @@ async fn _post_comment( request: HttpRequest, bytes: web::Bytes, ) -> Result<(), Error> { - match String::from_utf8(bytes.to_vec()) { - Ok(text) => { - let PostCommentsRequest { url, comment } = - match serde_json::from_str::(&text) { - Ok(mut req) => { - let mut sanitize_req = || -> Result<(), SanitizeError> { - req.comment.text = - sanitize_str(&DEFAULT, &req.comment.text)?.replace(">", ">"); // required for markdown quotes - if let Some(ref mut author) = req.comment.author { - *author = sanitize_str(&DEFAULT, &author)?; - } - Ok(()) - }; - if let Err(_) = sanitize_req() { - return Err(Error::SanitizationError); - } - req - } - Err(_) => { - return Err(Error::InvalidBody); - } - }; - if comment.validate().is_err() { - return Err(Error::InvalidFields); - } - let origin = match get_request_origin(&request) { - Some(origin) => origin, - None => return Err(Error::InvalidOrigin), - }; - // Check to see if provided URL is in scope. - // This is to prevent malicious requests that try to get server to fetch external websites. - // (requires loop because "labels on blocks are unstable") - // https://github.com/rust-lang/rust/issues/48594 - 'outer: loop { - for site_root in data.databases.keys() { - if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) { - break 'outer; - } - } - return Err(Error::InvalidUrl); + let PostCommentsRequest { url, comment } = { + let mut req = serde_json::from_str::(&String::from_utf8(bytes.to_vec())?)?; + req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)?.replace(">", ">"); // required for markdown quotes + if let Some(ref mut author) = req.comment.author { + *author = sanitize_str(&DEFAULT, &author)?; + } + req + }; + comment.validate()?; + let origin = get_request_origin(&request)?; + // Check to see if provided URL is in scope. + // This is to prevent malicious requests that try to get server to fetch external websites. + // (requires loop because "labels on blocks are unstable") + // https://github.com/rust-lang/rust/issues/48594 + 'outer: loop { + for site_root in data.databases.keys() { + if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) { + break 'outer; } - match get_page_data(&url).await { - Ok(page_data_option) => match page_data_option { - Some(page_data) => { - if page_data.content_id != comment.content_id { - return Err(Error::InvalidContentId); + } + return Err(Error::InvalidUrl); + } + match get_page_data(&url).await? { + Some(page_data) => { + if page_data.content_id != comment.content_id { + return Err(Error::InvalidContentId); + } + } + None => return Err(Error::InvalidUrl), // e.g. 404 + }; + web::block(move || { + let database = data.get_db(&origin)?; + if comment.author.is_none() && database.settings.name_required { + return Err(Error::NameRequired); + } + if comment.email.is_none() && database.settings.email_required { + return Err(Error::EmailRequired); + } + if let Some(parent) = comment.parent { + 'outer2: loop { + let comments = database.get_comments(&comment.content_id)?; + for other_comment in comments.iter() { + if other_comment.id.unwrap() == parent { + if other_comment.parent.is_none() { + break 'outer2; } + break; } - None => return Err(Error::InvalidUrl), // e.g. 404 - }, - Err(_) => { - return Err(Error::PageFetchError); - } - }; - match web::block(move || { - let database = match data.get_db(Some(origin)) { - Ok(database) => database, - Err(err) => return Err(err), - }; - if comment.author.is_none() && database.settings.name_required { - return Err(Error::NameRequired); } - if comment.email.is_none() && database.settings.email_required { - return Err(Error::EmailRequired); - } - if let Some(parent) = comment.parent { - 'outer2: loop { - match database.get_comments(&comment.content_id) { - Ok(comments) => { - for other_comment in comments.iter() { - if other_comment.id.unwrap() == parent { - if other_comment.parent.is_none() { - break 'outer2; - } - break; - } - } - } - Err(_) => { - return Err(Error::DatabaseInternalError); - } - }; - return Err(Error::InvalidParent); - } - } - if let Err(_) = database.create_comment(&comment) { - return Err(Error::DatabaseInternalError); - } - Ok(()) - }) - .await - { - Ok(result) => result, - Err(_) => Err(Error::DatabaseAccessError), + return Err(Error::InvalidParent); } } - Err(_) => Err(Error::InvalidBody), - } + database.create_comment(&comment)?; + Ok(()) + }).await? } #[post("/")]