diff --git a/src/main.rs b/src/main.rs index 9e17c94..78fcef7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,12 +19,22 @@ struct AppState { arguments: Arguments, } +// TODO: Make error handling system not bad. +// Currently, it's a horrible mix of custom error types a direct HttpResponses, +// due to not being able to pass HttpResponse out of web::block... +// Well, it works, at least. + enum DatabaseAccessError { BadOrigin, AccessError, DatabaseError, } +enum CommentCreationError { + DatabaseAccessError(DatabaseAccessError), + BadParent, +} + impl DatabaseAccessError { fn to_http_response(&self) -> HttpResponse { match self { @@ -39,6 +49,15 @@ impl DatabaseAccessError { } } +impl CommentCreationError { + fn to_http_response(&self) -> HttpResponse { + match self { + Self::DatabaseAccessError(error) => error.to_http_response(), + Self::BadParent => HttpResponse::BadRequest().reason("invalid comment parent").finish(), + } + } +} + impl AppState { fn get_db<'a>( &'a self, @@ -183,13 +202,8 @@ async fn post_comment( if comment.email.is_none() && data.arguments.email_required { return HttpResponse::BadRequest().reason("email required").finish(); } - let origin = match request.head().headers().get("Origin") { - Some(origin) => match origin.to_str() { - Ok(origin) => origin, - // If the Origin is not valid ASCII, it is a bad request not sent from a browser - Err(_) => return HttpResponse::BadRequest().reason("bad origin").finish(), - }, - // If there is no Origin header, it is a bad request not sent from a browser + let origin = match get_request_origin(&request) { + Some(origin) => origin, None => return HttpResponse::BadRequest().reason("bad origin").finish(), }; // Check to see if provided URL is in scope. @@ -198,7 +212,7 @@ async fn post_comment( // https://github.com/rust-lang/rust/issues/48594 'outer: loop { for site_root in data.databases.keys() { - if site_root.starts_with(origin) && url.starts_with(site_root) { + if site_root.starts_with(&origin) && url.starts_with(site_root) { break 'outer; } } @@ -223,40 +237,40 @@ async fn post_comment( .finish() } }; - let database = match data.get_db(&request) { - Ok(database) => database, - Err(err) => return err.to_http_response(), - }; - if let Some(parent) = comment.parent { - 'outer2: loop { - match database.get_comments(&comment.content_id) { - Ok(comments) => { - for other_comment in comments.iter() { - if other_comment.id.unwrap() == parent { - if other_comment.parent.is_none() { - break 'outer2; + match web::block(move || { + let database = match data.get_db_with_origin(Some(origin)) { + Ok(database) => database, + Err(err) => return Err(CommentCreationError::DatabaseAccessError(err)), + }; + if let Some(parent) = comment.parent { + 'outer2: loop { + match database.get_comments(&comment.content_id) { + Ok(comments) => { + for other_comment in comments.iter() { + if other_comment.id.unwrap() == parent { + if other_comment.parent.is_none() { + break 'outer2; + } + break; } - break; } } - } - Err(_) => { - return HttpResponse::InternalServerError() - .reason("failed to get comments") - .finish() - } + Err(_) => return Err(CommentCreationError::DatabaseAccessError(DatabaseAccessError::DatabaseError)), + }; + return Err(CommentCreationError::BadParent); } - return HttpResponse::BadRequest() - .reason("invalid comment parent") - .finish(); } + if let Err(_) = database.create_comment(&comment) { + return Err(CommentCreationError::DatabaseAccessError(DatabaseAccessError::DatabaseError)); + } + Ok(()) + }).await { + Ok(result) => match result { + Ok(_) => HttpResponse::Ok().into(), + Err(error) => error.to_http_response(), + }, + Err(_) => DatabaseAccessError::AccessError.to_http_response(), } - if let Err(_) = database.create_comment(&comment) { - return HttpResponse::InternalServerError() - .reason("failed to create comment") - .finish(); - } - HttpResponse::Ok().into() } Err(_) => HttpResponse::BadRequest() .reason("failed to parse request body")