Improve error handling with From derive

main
Elnu 2 years ago
parent 3f2f72a278
commit ea3e9bcd5a

@ -1,50 +1,72 @@
use actix_web::HttpResponse;
use crate::Database;
use actix_web::{HttpResponse, error::BlockingError};
use std::{sync::{PoisonError, MutexGuard}, string::FromUtf8Error};
use sanitize_html::errors::SanitizeError;
use validator::ValidationErrors;
use derive_more::From;
#[derive(From, Debug)]
pub enum Error {
InvalidOrigin,
InvalidBody,
InvalidBodyEncoding(FromUtf8Error),
InvalidBodyJson(serde_json::Error),
InvalidUrl,
InvalidFields,
InvalidFields(ValidationErrors),
InvalidContentId,
InvalidParent,
EmailRequired,
NameRequired,
DatabaseAccessError,
DatabaseInternalError,
SanitizationError,
PageFetchError,
DatabaseAccessBlockingError(BlockingError),
DatabaseAccessPoisonError,
DatabaseInternalError(rusqlite::Error),
SanitizationError(SanitizeError),
PageFetchError(reqwest::Error),
}
impl Error {
pub fn to_http_response(&self) -> HttpResponse {
match self {
Self::InvalidOrigin
| Self::InvalidBody
| Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_)
| Self::InvalidUrl
| Self::InvalidFields
| Self::InvalidFields(_)
| Self::InvalidContentId
| Self::InvalidParent
| Self::EmailRequired
| Self::NameRequired => HttpResponse::BadRequest(),
Self::DatabaseAccessError
| Self::DatabaseInternalError
| Self::SanitizationError
| Self::PageFetchError => HttpResponse::InternalServerError(),
Self::DatabaseAccessBlockingError(_)
| Self::DatabaseAccessPoisonError
| Self::DatabaseInternalError(_)
| Self::SanitizationError(_)
| Self::PageFetchError(_) => {
eprintln!("{:?}", self);
HttpResponse::InternalServerError()
},
}
.reason(match self {
Self::InvalidOrigin => "invalid request origin",
Self::InvalidBody => "invalid request body",
Self::InvalidBodyEncoding(_)
| Self::InvalidBodyJson(_) => "invalid request body",
Self::InvalidUrl => "invalid request url",
Self::InvalidFields => "invalid request field",
Self::InvalidFields(_) => "invalid request field",
Self::InvalidContentId => "invalid request content id",
Self::InvalidParent => "invalid comment parent",
Self::EmailRequired => "comment email required",
Self::NameRequired => "comment name required",
Self::DatabaseAccessError => "database access error",
Self::DatabaseInternalError => "database internal error",
Self::SanitizationError => "comment sanitization error",
Self::PageFetchError => "page fetch error",
Self::DatabaseAccessBlockingError(_)
| Self::DatabaseAccessPoisonError => "database access error",
Self::DatabaseInternalError(_) => "database internal error",
Self::SanitizationError(_) => "comment sanitization error",
Self::PageFetchError(_) => "page fetch error",
})
.finish()
}
}
impl<'a> From<PoisonError<MutexGuard<'a, Database>>> for Error {
fn from(_: PoisonError<MutexGuard<'a, Database>>) -> Self {
Self::DatabaseAccessPoisonError
}
}

@ -10,7 +10,7 @@ pub use error::*;
use actix_cors::Cors;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer};
use clap::Parser;
use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str};
use sanitize_html::{rules::predefined::DEFAULT, sanitize_str};
use scraper::{Html, Selector};
use serde::Deserialize;
use std::fs::File;
@ -23,16 +23,9 @@ struct AppState {
}
impl AppState {
fn get_db<'a>(&'a self, origin: Option<String>) -> Result<MutexGuard<'a, Database>, Error> {
let origin = match origin {
Some(origin) => origin,
None => return Err(Error::InvalidOrigin),
};
match self.databases.get(&origin) {
Some(database) => Ok(match database.lock() {
Ok(database) => database,
Err(_) => return Err(Error::DatabaseAccessError),
}),
fn get_db<'a>(&'a self, origin: &str) -> Result<MutexGuard<'a, Database>, Error> {
match self.databases.get(origin) {
Some(database) => Ok(database.lock()?),
None => return Err(Error::InvalidOrigin),
}
}
@ -42,13 +35,13 @@ fn trim_protocol(url: &str) -> String {
url.replace("http://", "").replace("https://", "")
}
fn get_request_origin(request: &HttpRequest) -> Option<String> {
fn get_request_origin(request: &HttpRequest) -> Result<String, Error> {
match request.head().headers().get("Origin") {
Some(origin) => match origin.to_str() {
Ok(origin) => Some(trim_protocol(origin)),
Err(_) => None,
Ok(origin) => Ok(trim_protocol(origin)),
Err(_) => Err(Error::InvalidOrigin),
},
None => None,
None => Err(Error::InvalidOrigin),
}
}
@ -77,25 +70,8 @@ async fn _get_comments(
request: HttpRequest,
content_id: web::Path<String>,
) -> Result<Vec<Comment>, Error> {
let origin = get_request_origin(&request);
match web::block(move || {
Ok(
match match data.get_db(origin) {
Ok(database) => database,
Err(err) => return Err(err),
}
.get_comments(&content_id)
{
Ok(comments) => comments,
Err(_) => return Err(Error::DatabaseInternalError),
},
)
})
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
}
let origin = get_request_origin(&request)?;
web::block(move || Ok(data.get_db(&origin)?.get_comments(&content_id)?)).await?
}
#[get("/{content_id}")]
@ -121,104 +97,61 @@ async fn _post_comment(
request: HttpRequest,
bytes: web::Bytes,
) -> Result<(), Error> {
match String::from_utf8(bytes.to_vec()) {
Ok(text) => {
let PostCommentsRequest { url, comment } =
match serde_json::from_str::<PostCommentsRequest>(&text) {
Ok(mut req) => {
let mut sanitize_req = || -> Result<(), SanitizeError> {
req.comment.text =
sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes
if let Some(ref mut author) = req.comment.author {
*author = sanitize_str(&DEFAULT, &author)?;
}
Ok(())
};
if let Err(_) = sanitize_req() {
return Err(Error::SanitizationError);
}
req
}
Err(_) => {
return Err(Error::InvalidBody);
}
};
if comment.validate().is_err() {
return Err(Error::InvalidFields);
let PostCommentsRequest { url, comment } = {
let mut req = serde_json::from_str::<PostCommentsRequest>(&String::from_utf8(bytes.to_vec())?)?;
req.comment.text = sanitize_str(&DEFAULT, &req.comment.text)?.replace("&gt;", ">"); // required for markdown quotes
if let Some(ref mut author) = req.comment.author {
*author = sanitize_str(&DEFAULT, &author)?;
}
req
};
comment.validate()?;
let origin = get_request_origin(&request)?;
// Check to see if provided URL is in scope.
// This is to prevent malicious requests that try to get server to fetch external websites.
// (requires loop because "labels on blocks are unstable")
// https://github.com/rust-lang/rust/issues/48594
'outer: loop {
for site_root in data.databases.keys() {
if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) {
break 'outer;
}
let origin = match get_request_origin(&request) {
Some(origin) => origin,
None => return Err(Error::InvalidOrigin),
};
// Check to see if provided URL is in scope.
// This is to prevent malicious requests that try to get server to fetch external websites.
// (requires loop because "labels on blocks are unstable")
// https://github.com/rust-lang/rust/issues/48594
'outer: loop {
for site_root in data.databases.keys() {
if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) {
break 'outer;
}
}
return Err(Error::InvalidUrl);
}
return Err(Error::InvalidUrl);
}
match get_page_data(&url).await? {
Some(page_data) => {
if page_data.content_id != comment.content_id {
return Err(Error::InvalidContentId);
}
match get_page_data(&url).await {
Ok(page_data_option) => match page_data_option {
Some(page_data) => {
if page_data.content_id != comment.content_id {
return Err(Error::InvalidContentId);
}
None => return Err(Error::InvalidUrl), // e.g. 404
};
web::block(move || {
let database = data.get_db(&origin)?;
if comment.author.is_none() && database.settings.name_required {
return Err(Error::NameRequired);
}
if comment.email.is_none() && database.settings.email_required {
return Err(Error::EmailRequired);
}
if let Some(parent) = comment.parent {
'outer2: loop {
let comments = database.get_comments(&comment.content_id)?;
for other_comment in comments.iter() {
if other_comment.id.unwrap() == parent {
if other_comment.parent.is_none() {
break 'outer2;
}
break;
}
None => return Err(Error::InvalidUrl), // e.g. 404
},
Err(_) => {
return Err(Error::PageFetchError);
}
};
match web::block(move || {
let database = match data.get_db(Some(origin)) {
Ok(database) => database,
Err(err) => return Err(err),
};
if comment.author.is_none() && database.settings.name_required {
return Err(Error::NameRequired);
}
if comment.email.is_none() && database.settings.email_required {
return Err(Error::EmailRequired);
}
if let Some(parent) = comment.parent {
'outer2: loop {
match database.get_comments(&comment.content_id) {
Ok(comments) => {
for other_comment in comments.iter() {
if other_comment.id.unwrap() == parent {
if other_comment.parent.is_none() {
break 'outer2;
}
break;
}
}
}
Err(_) => {
return Err(Error::DatabaseInternalError);
}
};
return Err(Error::InvalidParent);
}
}
if let Err(_) = database.create_comment(&comment) {
return Err(Error::DatabaseInternalError);
}
Ok(())
})
.await
{
Ok(result) => result,
Err(_) => Err(Error::DatabaseAccessError),
return Err(Error::InvalidParent);
}
}
Err(_) => Err(Error::InvalidBody),
}
database.create_comment(&comment)?;
Ok(())
}).await?
}
#[post("/")]

Loading…
Cancel
Save