Improve error handling with From derive

main
Elnu 2 years ago
parent 3f2f72a278
commit ea3e9bcd5a

@ -1,50 +1,72 @@
use actix_web::HttpResponse; use crate::Database;
use actix_web::{HttpResponse, error::BlockingError};
use std::{sync::{PoisonError, MutexGuard}, string::FromUtf8Error};
use sanitize_html::errors::SanitizeError;
use validator::ValidationErrors;
use derive_more::From;
#[derive(From, Debug)]
pub enum Error { pub enum Error {
InvalidOrigin, InvalidOrigin,
InvalidBody, InvalidBodyEncoding(FromUtf8Error),
InvalidBodyJson(serde_json::Error),
InvalidUrl, InvalidUrl,
InvalidFields, InvalidFields(ValidationErrors),
InvalidContentId, InvalidContentId,
InvalidParent, InvalidParent,
EmailRequired, EmailRequired,
NameRequired, NameRequired,
DatabaseAccessError, DatabaseAccessBlockingError(BlockingError),
DatabaseInternalError, DatabaseAccessPoisonError,
SanitizationError, DatabaseInternalError(rusqlite::Error),
PageFetchError, SanitizationError(SanitizeError),
PageFetchError(reqwest::Error),
} }
impl Error { impl Error {
pub fn to_http_response(&self) -> HttpResponse { pub fn to_http_response(&self) -> HttpResponse {
match self { match self {
Self::InvalidOrigin Self::InvalidOrigin
| Self::InvalidBody | Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_)
| Self::InvalidUrl | Self::InvalidUrl
| Self::InvalidFields | Self::InvalidFields(_)
| Self::InvalidContentId | Self::InvalidContentId
| Self::InvalidParent | Self::InvalidParent
| Self::EmailRequired | Self::EmailRequired
| Self::NameRequired => HttpResponse::BadRequest(), | Self::NameRequired => HttpResponse::BadRequest(),
Self::DatabaseAccessError Self::DatabaseAccessBlockingError(_)
| Self::DatabaseInternalError | Self::DatabaseAccessPoisonError
| Self::SanitizationError | Self::DatabaseInternalError(_)
| Self::PageFetchError => HttpResponse::InternalServerError(), | Self::SanitizationError(_)
| Self::PageFetchError(_) => {
eprintln!("{:?}", self);
HttpResponse::InternalServerError()
},
} }
.reason(match self { .reason(match self {
Self::InvalidOrigin => "invalid request origin", Self::InvalidOrigin => "invalid request origin",
Self::InvalidBody => "invalid request body", Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_) => "invalid request body",
Self::InvalidUrl => "invalid request url", Self::InvalidUrl => "invalid request url",
Self::InvalidFields => "invalid request field", Self::InvalidFields(_) => "invalid request field",
Self::InvalidContentId => "invalid request content id", Self::InvalidContentId => "invalid request content id",
Self::InvalidParent => "invalid comment parent", Self::InvalidParent => "invalid comment parent",
Self::EmailRequired => "comment email required", Self::EmailRequired => "comment email required",
Self::NameRequired => "comment name required", Self::NameRequired => "comment name required",
Self::DatabaseAccessError => "database access error", Self::DatabaseAccessBlockingError(_)
Self::DatabaseInternalError => "database internal error", | Self::DatabaseAccessPoisonError => "database access error",
Self::SanitizationError => "comment sanitization error", Self::DatabaseInternalError(_) => "database internal error",
Self::PageFetchError => "page fetch error", Self::SanitizationError(_) => "comment sanitization error",
Self::PageFetchError(_) => "page fetch error",
}) })
.finish() .finish()
} }
} }
impl<'a> From<PoisonError<MutexGuard<'a, Database>>> for Error {
fn from(_: PoisonError<MutexGuard<'a, Database>>) -> Self {
Self::DatabaseAccessPoisonError
}
}

@ -10,7 +10,7 @@ pub use error::*;
use actix_cors::Cors; use actix_cors::Cors;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer}; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer};
use clap::Parser; use clap::Parser;
use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str}; use sanitize_html::{rules::predefined::DEFAULT, sanitize_str};
use scraper::{Html, Selector}; use scraper::{Html, Selector};
use serde::Deserialize; use serde::Deserialize;
use std::fs::File; use std::fs::File;
@ -23,16 +23,9 @@ struct AppState {
} }
impl AppState { impl AppState {
fn get_db<'a>(&'a self, origin: Option<String>) -> Result<MutexGuard<'a, Database>, Error> { fn get_db<'a>(&'a self, origin: &str) -> Result<MutexGuard<'a, Database>, Error> {
let origin = match origin { match self.databases.get(origin) {
Some(origin) => origin, Some(database) => Ok(database.lock()?),
None => return Err(Error::InvalidOrigin),
};
match self.databases.get(&origin) {
Some(database) => Ok(match database.lock() {
Ok(database) => database,
Err(_) => return Err(Error::DatabaseAccessError),
}),
None => return Err(Error::InvalidOrigin), None => return Err(Error::InvalidOrigin),
} }
} }
@ -42,13 +35,13 @@ fn trim_protocol(url: &str) -> String {
url.replace("http://", "").replace("https://", "") url.replace("http://", "").replace("https://", "")
} }
fn get_request_origin(request: &HttpRequest) -> Option<String> { fn get_request_origin(request: &HttpRequest) -> Result<String, Error> {
match request.head().headers().get("Origin") { match request.head().headers().get("Origin") {
Some(origin) => match origin.to_str() { Some(origin) => match origin.to_str() {
Ok(origin) => Some(trim_protocol(origin)), Ok(origin) => Ok(trim_protocol(origin)),
Err(_) => None, Err(_) => Err(Error::InvalidOrigin),
}, },
None => None, None => Err(Error::InvalidOrigin),
} }
} }
@ -77,25 +70,8 @@ async fn _get_comments(
request: HttpRequest, request: HttpRequest,
content_id: web::Path<String>, content_id: web::Path<String>,
) -> Result<Vec<Comment>, Error> { ) -> Result<Vec<Comment>, Error> {
let origin = get_request_origin(&request); let origin = get_request_origin(&request)?;
match web::block(move || { web::block(move || Ok(data.get_db(&origin)?.get_comments(&content_id)?)).await?
Ok(
match match data.get_db(origin) {
Ok(database) => database,
Err(err) => return Err(err),
}
.get_comments(&content_id)
{
Ok(comments) => comments,
Err(_) => return Err(Error::DatabaseInternalError),
},
)
})
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
}
} }
#[get("/{content_id}")] #[get("/{content_id}")]
@ -121,35 +97,16 @@ async fn _post_comment(
request: HttpRequest, request: HttpRequest,
bytes: web::Bytes, bytes: web::Bytes,
) -> Result<(), Error> { ) -> Result<(), Error> {
match String::from_utf8(bytes.to_vec()) { let PostCommentsRequest { url, comment } = {
Ok(text) => { let mut req = serde_json::from_str::<PostCommentsRequest>(&String::from_utf8(bytes.to_vec())?)?;
let PostCommentsRequest { url, comment } = req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes
match serde_json::from_str::<PostCommentsRequest>(&text) {
Ok(mut req) => {
let mut sanitize_req = || -> Result<(), SanitizeError> {
req.comment.text =
sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes
if let Some(ref mut author) = req.comment.author { if let Some(ref mut author) = req.comment.author {
*author = sanitize_str(&DEFAULT, &author)?; *author = sanitize_str(&DEFAULT, &author)?;
} }
Ok(())
};
if let Err(_) = sanitize_req() {
return Err(Error::SanitizationError);
}
req req
}
Err(_) => {
return Err(Error::InvalidBody);
}
};
if comment.validate().is_err() {
return Err(Error::InvalidFields);
}
let origin = match get_request_origin(&request) {
Some(origin) => origin,
None => return Err(Error::InvalidOrigin),
}; };
comment.validate()?;
let origin = get_request_origin(&request)?;
// Check to see if provided URL is in scope. // Check to see if provided URL is in scope.
// This is to prevent malicious requests that try to get server to fetch external websites. // This is to prevent malicious requests that try to get server to fetch external websites.
// (requires loop because "labels on blocks are unstable") // (requires loop because "labels on blocks are unstable")
@ -162,24 +119,16 @@ async fn _post_comment(
} }
return Err(Error::InvalidUrl); return Err(Error::InvalidUrl);
} }
match get_page_data(&url).await { match get_page_data(&url).await? {
Ok(page_data_option) => match page_data_option {
Some(page_data) => { Some(page_data) => {
if page_data.content_id != comment.content_id { if page_data.content_id != comment.content_id {
return Err(Error::InvalidContentId); return Err(Error::InvalidContentId);
} }
} }
None => return Err(Error::InvalidUrl), // e.g. 404 None => return Err(Error::InvalidUrl), // e.g. 404
},
Err(_) => {
return Err(Error::PageFetchError);
}
};
match web::block(move || {
let database = match data.get_db(Some(origin)) {
Ok(database) => database,
Err(err) => return Err(err),
}; };
web::block(move || {
let database = data.get_db(&origin)?;
if comment.author.is_none() && database.settings.name_required { if comment.author.is_none() && database.settings.name_required {
return Err(Error::NameRequired); return Err(Error::NameRequired);
} }
@ -188,8 +137,7 @@ async fn _post_comment(
} }
if let Some(parent) = comment.parent { if let Some(parent) = comment.parent {
'outer2: loop { 'outer2: loop {
match database.get_comments(&comment.content_id) { let comments = database.get_comments(&comment.content_id)?;
Ok(comments) => {
for other_comment in comments.iter() { for other_comment in comments.iter() {
if other_comment.id.unwrap() == parent { if other_comment.id.unwrap() == parent {
if other_comment.parent.is_none() { if other_comment.parent.is_none() {
@ -198,27 +146,12 @@ async fn _post_comment(
break; break;
} }
} }
}
Err(_) => {
return Err(Error::DatabaseInternalError);
}
};
return Err(Error::InvalidParent); return Err(Error::InvalidParent);
} }
} }
if let Err(_) = database.create_comment(&comment) { database.create_comment(&comment)?;
return Err(Error::DatabaseInternalError);
}
Ok(()) Ok(())
}) }).await?
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
}
}
Err(_) => Err(Error::InvalidBody),
}
} }
#[post("/")] #[post("/")]

Loading…
Cancel
Save