Improve error handling with From derive

main
Elnu 2 years ago
parent 3f2f72a278
commit ea3e9bcd5a

@ -1,50 +1,72 @@
use actix_web::HttpResponse; use crate::Database;
use actix_web::{HttpResponse, error::BlockingError};
use std::{sync::{PoisonError, MutexGuard}, string::FromUtf8Error};
use sanitize_html::errors::SanitizeError;
use validator::ValidationErrors;
use derive_more::From;
#[derive(From, Debug)]
pub enum Error { pub enum Error {
InvalidOrigin, InvalidOrigin,
InvalidBody, InvalidBodyEncoding(FromUtf8Error),
InvalidBodyJson(serde_json::Error),
InvalidUrl, InvalidUrl,
InvalidFields, InvalidFields(ValidationErrors),
InvalidContentId, InvalidContentId,
InvalidParent, InvalidParent,
EmailRequired, EmailRequired,
NameRequired, NameRequired,
DatabaseAccessError, DatabaseAccessBlockingError(BlockingError),
DatabaseInternalError, DatabaseAccessPoisonError,
SanitizationError, DatabaseInternalError(rusqlite::Error),
PageFetchError, SanitizationError(SanitizeError),
PageFetchError(reqwest::Error),
} }
impl Error { impl Error {
pub fn to_http_response(&self) -> HttpResponse { pub fn to_http_response(&self) -> HttpResponse {
match self { match self {
Self::InvalidOrigin Self::InvalidOrigin
| Self::InvalidBody | Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_)
| Self::InvalidUrl | Self::InvalidUrl
| Self::InvalidFields | Self::InvalidFields(_)
| Self::InvalidContentId | Self::InvalidContentId
| Self::InvalidParent | Self::InvalidParent
| Self::EmailRequired | Self::EmailRequired
| Self::NameRequired => HttpResponse::BadRequest(), | Self::NameRequired => HttpResponse::BadRequest(),
Self::DatabaseAccessError Self::DatabaseAccessBlockingError(_)
| Self::DatabaseInternalError | Self::DatabaseAccessPoisonError
| Self::SanitizationError | Self::DatabaseInternalError(_)
| Self::PageFetchError => HttpResponse::InternalServerError(), | Self::SanitizationError(_)
| Self::PageFetchError(_) => {
eprintln!("{:?}", self);
HttpResponse::InternalServerError()
},
} }
.reason(match self { .reason(match self {
Self::InvalidOrigin => "invalid request origin", Self::InvalidOrigin => "invalid request origin",
Self::InvalidBody => "invalid request body", Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_) => "invalid request body",
Self::InvalidUrl => "invalid request url", Self::InvalidUrl => "invalid request url",
Self::InvalidFields => "invalid request field", Self::InvalidFields(_) => "invalid request field",
Self::InvalidContentId => "invalid request content id", Self::InvalidContentId => "invalid request content id",
Self::InvalidParent => "invalid comment parent", Self::InvalidParent => "invalid comment parent",
Self::EmailRequired => "comment email required", Self::EmailRequired => "comment email required",
Self::NameRequired => "comment name required", Self::NameRequired => "comment name required",
Self::DatabaseAccessError => "database access error", Self::DatabaseAccessBlockingError(_)
Self::DatabaseInternalError => "database internal error", | Self::DatabaseAccessPoisonError => "database access error",
Self::SanitizationError => "comment sanitization error", Self::DatabaseInternalError(_) => "database internal error",
Self::PageFetchError => "page fetch error", Self::SanitizationError(_) => "comment sanitization error",
Self::PageFetchError(_) => "page fetch error",
}) })
.finish() .finish()
} }
} }
impl<'a> From<PoisonError<MutexGuard<'a, Database>>> for Error {
fn from(_: PoisonError<MutexGuard<'a, Database>>) -> Self {
Self::DatabaseAccessPoisonError
}
}

@ -10,7 +10,7 @@ pub use error::*;
use actix_cors::Cors; use actix_cors::Cors;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer}; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer};
use clap::Parser; use clap::Parser;
use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str}; use sanitize_html::{rules::predefined::DEFAULT, sanitize_str};
use scraper::{Html, Selector}; use scraper::{Html, Selector};
use serde::Deserialize; use serde::Deserialize;
use std::fs::File; use std::fs::File;
@ -23,16 +23,9 @@ struct AppState {
} }
impl AppState { impl AppState {
fn get_db<'a>(&'a self, origin: Option<String>) -> Result<MutexGuard<'a, Database>, Error> { fn get_db<'a>(&'a self, origin: &str) -> Result<MutexGuard<'a, Database>, Error> {
let origin = match origin { match self.databases.get(origin) {
Some(origin) => origin, Some(database) => Ok(database.lock()?),
None => return Err(Error::InvalidOrigin),
};
match self.databases.get(&origin) {
Some(database) => Ok(match database.lock() {
Ok(database) => database,
Err(_) => return Err(Error::DatabaseAccessError),
}),
None => return Err(Error::InvalidOrigin), None => return Err(Error::InvalidOrigin),
} }
} }
@ -42,13 +35,13 @@ fn trim_protocol(url: &str) -> String {
url.replace("http://", "").replace("https://", "") url.replace("http://", "").replace("https://", "")
} }
fn get_request_origin(request: &HttpRequest) -> Option<String> { fn get_request_origin(request: &HttpRequest) -> Result<String, Error> {
match request.head().headers().get("Origin") { match request.head().headers().get("Origin") {
Some(origin) => match origin.to_str() { Some(origin) => match origin.to_str() {
Ok(origin) => Some(trim_protocol(origin)), Ok(origin) => Ok(trim_protocol(origin)),
Err(_) => None, Err(_) => Err(Error::InvalidOrigin),
}, },
None => None, None => Err(Error::InvalidOrigin),
} }
} }
@ -77,25 +70,8 @@ async fn _get_comments(
request: HttpRequest, request: HttpRequest,
content_id: web::Path<String>, content_id: web::Path<String>,
) -> Result<Vec<Comment>, Error> { ) -> Result<Vec<Comment>, Error> {
let origin = get_request_origin(&request); let origin = get_request_origin(&request)?;
match web::block(move || { web::block(move || Ok(data.get_db(&origin)?.get_comments(&content_id)?)).await?
Ok(
match match data.get_db(origin) {
Ok(database) => database,
Err(err) => return Err(err),
}
.get_comments(&content_id)
{
Ok(comments) => comments,
Err(_) => return Err(Error::DatabaseInternalError),
},
)
})
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
}
} }
#[get("/{content_id}")] #[get("/{content_id}")]
@ -121,104 +97,61 @@ async fn _post_comment(
request: HttpRequest, request: HttpRequest,
bytes: web::Bytes, bytes: web::Bytes,
) -> Result<(), Error> { ) -> Result<(), Error> {
match String::from_utf8(bytes.to_vec()) { let PostCommentsRequest { url, comment } = {
Ok(text) => { let mut req = serde_json::from_str::<PostCommentsRequest>(&String::from_utf8(bytes.to_vec())?)?;
let PostCommentsRequest { url, comment } = req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes
match serde_json::from_str::<PostCommentsRequest>(&text) { if let Some(ref mut author) = req.comment.author {
Ok(mut req) => { *author = sanitize_str(&DEFAULT, &author)?;
let mut sanitize_req = || -> Result<(), SanitizeError> { }
req.comment.text = req
sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes };
if let Some(ref mut author) = req.comment.author { comment.validate()?;
*author = sanitize_str(&DEFAULT, &author)?; let origin = get_request_origin(&request)?;
} // Check to see if provided URL is in scope.
Ok(()) // This is to prevent malicious requests that try to get server to fetch external websites.
}; // (requires loop because "labels on blocks are unstable")
if let Err(_) = sanitize_req() { // https://github.com/rust-lang/rust/issues/48594
return Err(Error::SanitizationError); 'outer: loop {
} for site_root in data.databases.keys() {
req if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) {
} break 'outer;
Err(_) => {
return Err(Error::InvalidBody);
}
};
if comment.validate().is_err() {
return Err(Error::InvalidFields);
}
let origin = match get_request_origin(&request) {
Some(origin) => origin,
None => return Err(Error::InvalidOrigin),
};
// Check to see if provided URL is in scope.
// This is to prevent malicious requests that try to get server to fetch external websites.
// (requires loop because "labels on blocks are unstable")
// https://github.com/rust-lang/rust/issues/48594
'outer: loop {
for site_root in data.databases.keys() {
if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) {
break 'outer;
}
}
return Err(Error::InvalidUrl);
} }
match get_page_data(&url).await { }
Ok(page_data_option) => match page_data_option { return Err(Error::InvalidUrl);
Some(page_data) => { }
if page_data.content_id != comment.content_id { match get_page_data(&url).await? {
return Err(Error::InvalidContentId); Some(page_data) => {
if page_data.content_id != comment.content_id {
return Err(Error::InvalidContentId);
}
}
None => return Err(Error::InvalidUrl), // e.g. 404
};
web::block(move || {
let database = data.get_db(&origin)?;
if comment.author.is_none() && database.settings.name_required {
return Err(Error::NameRequired);
}
if comment.email.is_none() && database.settings.email_required {
return Err(Error::EmailRequired);
}
if let Some(parent) = comment.parent {
'outer2: loop {
let comments = database.get_comments(&comment.content_id)?;
for other_comment in comments.iter() {
if other_comment.id.unwrap() == parent {
if other_comment.parent.is_none() {
break 'outer2;
} }
break;
} }
None => return Err(Error::InvalidUrl), // e.g. 404
},
Err(_) => {
return Err(Error::PageFetchError);
}
};
match web::block(move || {
let database = match data.get_db(Some(origin)) {
Ok(database) => database,
Err(err) => return Err(err),
};
if comment.author.is_none() && database.settings.name_required {
return Err(Error::NameRequired);
} }
if comment.email.is_none() && database.settings.email_required { return Err(Error::InvalidParent);
return Err(Error::EmailRequired);
}
if let Some(parent) = comment.parent {
'outer2: loop {
match database.get_comments(&comment.content_id) {
Ok(comments) => {
for other_comment in comments.iter() {
if other_comment.id.unwrap() == parent {
if other_comment.parent.is_none() {
break 'outer2;
}
break;
}
}
}
Err(_) => {
return Err(Error::DatabaseInternalError);
}
};
return Err(Error::InvalidParent);
}
}
if let Err(_) = database.create_comment(&comment) {
return Err(Error::DatabaseInternalError);
}
Ok(())
})
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
} }
} }
Err(_) => Err(Error::InvalidBody), database.create_comment(&comment)?;
} Ok(())
}).await?
} }
#[post("/")] #[post("/")]

Loading…
Cancel
Save