diff --git a/src/main.rs b/src/main.rs index 9e16000..8ddc3a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,8 +8,8 @@ pub use database::Database; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use scraper::{Html, Selector}; use serde::Deserialize; -use std::collections::HashMap; -use std::sync::{Mutex, MutexGuard}; +use std::{collections::HashMap, sync::MutexGuard}; +use std::sync::Mutex; use validator::Validate; use sanitize_html::{sanitize_str, rules::predefined::DEFAULT, errors::SanitizeError}; use clap::Parser; @@ -19,6 +19,58 @@ struct AppState { arguments: Arguments, } +enum DatabaseAccessError { + BadOrigin, + AccessError, + DatabaseError +} + +impl DatabaseAccessError { + fn to_http_response(&self) -> HttpResponse { + match self { + Self::BadOrigin => HttpResponse::BadRequest().reason("bad origin").finish(), + Self::AccessError => HttpResponse::InternalServerError().reason("database access error").finish(), // e.g. PoisonError + Self::DatabaseError => HttpResponse::InternalServerError().reason("database error").finish(), + } + } +} + +impl AppState { + fn get_db<'a>( + &'a self, + request: &HttpRequest, + ) -> Result, DatabaseAccessError> { + self.get_db_with_origin(get_request_origin(request)) + } + + fn get_db_with_origin<'a>( + &'a self, + origin: Option, + ) -> Result, DatabaseAccessError> { + let origin = match origin { + Some(origin) => origin, + None => return Err(DatabaseAccessError::BadOrigin), + }; + match self.databases.get(&origin) { + Some(database) => Ok(match database.lock() { + Ok(database) => database, + Err(_) => return Err(DatabaseAccessError::AccessError), + }), + None => return Err(DatabaseAccessError::AccessError), + } + } +} + +fn get_request_origin(request: &HttpRequest) -> Option { + match request.head().headers().get("Origin") { + Some(origin) => match origin.to_str() { + Ok(origin) => Some(origin.to_owned()), + Err(_) => None, + }, + None => None, + } +} + #[derive(Default, Parser)] #[clap(author, version, about)] struct Arguments { @@ -34,37 +86,29 @@ struct Arguments { email_required: bool, } -fn get_db<'a>( - data: &'a web::Data, - request: &HttpRequest, -) -> Result, HttpResponse> { - let origin = match request.head().headers().get("Origin") { - Some(origin) => match origin.to_str() { - Ok(origin) => origin, - Err(_) => return Err(HttpResponse::BadRequest().reason("bad origin").finish()), - }, - None => return Err(HttpResponse::BadRequest().reason("bad origin").finish()), - }; - match data.databases.get(origin) { - Some(database) => Ok(match database.lock() { - Ok(database) => database, - Err(_) => return Err(HttpResponse::InternalServerError().reason("database error").finish()), - }), - None => return Err(HttpResponse::BadRequest().reason("bad origin").finish()), - } -} - #[get("/{content_id}")] async fn get_comments( data: web::Data, request: HttpRequest, content_id: web::Path, ) -> impl Responder { - let database = match get_db(&data, &request) { - Ok(database) => database, - Err(response) => return response, + let origin = get_request_origin(&request); + let comments = match web::block(move || { + Ok(match match data.get_db_with_origin(origin) { + Ok(database) => database, + Err(err) => return Err(err), + }.get_comments(&content_id) { + Ok(comments) => comments, + Err(_) => return Err(DatabaseAccessError::DatabaseError), + }) + }).await { + Ok(comments) => match comments { + Ok(comments) => comments, + Err(err) => return err.to_http_response(), + }, + Err(_) => return DatabaseAccessError::AccessError.to_http_response(), }; - HttpResponse::Ok().json(database.get_comments(&content_id).unwrap()) + HttpResponse::Ok().json(comments) } #[derive(Deserialize)] @@ -139,9 +183,9 @@ async fn post_comment( }, Err(_) => return HttpResponse::InternalServerError().reason("failed to get page data").finish(), }; - let database = match get_db(&data, &request) { + let database = match data.get_db(&request) { Ok(database) => database, - Err(response) => return response, + Err(err) => return err.to_http_response(), }; if let Some(parent) = comment.parent { 'outer2: loop {