Add configuration file, make HTTP/HTTPS agnostic
This commit is contained in:
parent
98dc007fa4
commit
4bfcf6631e
6 changed files with 110 additions and 52 deletions
|
@ -1,17 +1,41 @@
|
|||
use crate::Comment;
|
||||
use rusqlite::{params, Connection, Result};
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use derive_more::From;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub struct Database {
|
||||
conn: Connection,
|
||||
pub settings: DatabaseSettings,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(default)]
|
||||
pub struct DatabaseSettings {
|
||||
pub name_required: bool,
|
||||
pub email_required: bool,
|
||||
pub file: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(From, Debug)]
|
||||
pub enum DatabaseCreationError {
|
||||
RusqliteError(rusqlite::Error),
|
||||
IoError(std::io::Error),
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub fn new(testing: bool, name: &str) -> Result<Self> {
|
||||
let name = name.replace("http://", "").replace("https://", "");
|
||||
pub fn new(testing: bool, name: &str, settings: DatabaseSettings) -> Result<Self, DatabaseCreationError> {
|
||||
let conn = if testing {
|
||||
Connection::open_in_memory()
|
||||
} else {
|
||||
Connection::open(format!("{name}.db"))
|
||||
let path = PathBuf::from(match &settings.file {
|
||||
Some(path) => path.clone(),
|
||||
None => format!("{name}.db"),
|
||||
});
|
||||
fs::create_dir_all(path.parent().unwrap())?;
|
||||
Connection::open(path)
|
||||
}?;
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS comment (
|
||||
|
@ -25,7 +49,7 @@ impl Database {
|
|||
)",
|
||||
params![],
|
||||
)?;
|
||||
Ok(Self { conn })
|
||||
Ok(Self { conn, settings })
|
||||
}
|
||||
|
||||
pub fn get_comments(&self, content_id: &str) -> Result<Vec<Comment>> {
|
||||
|
|
57
src/main.rs
57
src/main.rs
|
@ -2,10 +2,10 @@ mod comment;
|
|||
pub use comment::*;
|
||||
|
||||
mod database;
|
||||
pub use database::Database;
|
||||
pub use database::*;
|
||||
|
||||
mod error;
|
||||
pub use error::Error;
|
||||
pub use error::*;
|
||||
|
||||
use actix_cors::Cors;
|
||||
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer};
|
||||
|
@ -13,13 +13,13 @@ use clap::Parser;
|
|||
use sanitize_html::{errors::SanitizeError, rules::predefined::DEFAULT, sanitize_str};
|
||||
use scraper::{Html, Selector};
|
||||
use serde::Deserialize;
|
||||
use std::fs::File;
|
||||
use std::sync::Mutex;
|
||||
use std::{collections::HashMap, sync::MutexGuard};
|
||||
use validator::Validate;
|
||||
|
||||
struct AppState {
|
||||
databases: HashMap<String, Mutex<Database>>,
|
||||
arguments: Arguments,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
|
@ -38,10 +38,14 @@ impl AppState {
|
|||
}
|
||||
}
|
||||
|
||||
fn trim_protocol(url: &str) -> String {
|
||||
url.replace("http://", "").replace("https://", "")
|
||||
}
|
||||
|
||||
fn get_request_origin(request: &HttpRequest) -> Option<String> {
|
||||
match request.head().headers().get("Origin") {
|
||||
Some(origin) => match origin.to_str() {
|
||||
Ok(origin) => Some(origin.to_owned()),
|
||||
Ok(origin) => Some(trim_protocol(origin)),
|
||||
Err(_) => None,
|
||||
},
|
||||
None => None,
|
||||
|
@ -51,6 +55,8 @@ fn get_request_origin(request: &HttpRequest) -> Option<String> {
|
|||
#[derive(Default, Parser)]
|
||||
#[clap(author, version, about)]
|
||||
struct Arguments {
|
||||
#[clap(default_value = "soudan.yaml", help = "Set configuration file")]
|
||||
config: String,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
|
@ -58,22 +64,12 @@ struct Arguments {
|
|||
help = "Set port where HTTP requests will be received"
|
||||
)]
|
||||
port: u16,
|
||||
#[clap(
|
||||
required = true,
|
||||
min_values = 1,
|
||||
help = "Set sites where comments will be posted"
|
||||
)]
|
||||
sites: Vec<String>,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
help = "Run in testing mode, with in-memory database(s) and permissive CORS policy"
|
||||
)]
|
||||
testing: bool,
|
||||
#[clap(short, long, help = "Require name for comment submissions")]
|
||||
name_required: bool,
|
||||
#[clap(short, long, help = "Require email for comment submissions")]
|
||||
email_required: bool,
|
||||
}
|
||||
|
||||
async fn _get_comments(
|
||||
|
@ -149,13 +145,7 @@ async fn _post_comment(
|
|||
};
|
||||
if comment.validate().is_err() {
|
||||
return Err(Error::InvalidFields);
|
||||
}
|
||||
if comment.author.is_none() && data.arguments.name_required {
|
||||
return Err(Error::NameRequired);
|
||||
}
|
||||
if comment.email.is_none() && data.arguments.email_required {
|
||||
return Err(Error::EmailRequired);
|
||||
}
|
||||
}
|
||||
let origin = match get_request_origin(&request) {
|
||||
Some(origin) => origin,
|
||||
None => return Err(Error::InvalidOrigin),
|
||||
|
@ -166,7 +156,7 @@ async fn _post_comment(
|
|||
// https://github.com/rust-lang/rust/issues/48594
|
||||
'outer: loop {
|
||||
for site_root in data.databases.keys() {
|
||||
if site_root.starts_with(&origin) && url.starts_with(site_root) {
|
||||
if site_root.eq(&origin) && trim_protocol(&url).starts_with(site_root) {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
|
@ -190,6 +180,12 @@ async fn _post_comment(
|
|||
Ok(database) => database,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
if comment.author.is_none() && database.settings.name_required {
|
||||
return Err(Error::NameRequired);
|
||||
}
|
||||
if comment.email.is_none() && database.settings.email_required {
|
||||
return Err(Error::EmailRequired);
|
||||
}
|
||||
if let Some(parent) = comment.parent {
|
||||
'outer2: loop {
|
||||
match database.get_comments(&comment.content_id) {
|
||||
|
@ -270,20 +266,21 @@ async fn get_page_data(url: &str) -> Result<Option<PageData>, reqwest::Error> {
|
|||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> Result<(), std::io::Error> {
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let arguments = Arguments::parse();
|
||||
let database_settings: HashMap<String, DatabaseSettings> = match serde_yaml::from_reader(File::open(arguments.config)?) {
|
||||
Ok(settings) => settings,
|
||||
Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::Other, "invalid config file")),
|
||||
};
|
||||
let mut databases = HashMap::new();
|
||||
for domain in arguments.sites.iter() {
|
||||
for (site, settings) in database_settings.iter() {
|
||||
databases.insert(
|
||||
domain.to_owned(),
|
||||
Mutex::new(Database::new(arguments.testing, domain).unwrap()),
|
||||
site.to_owned(),
|
||||
Mutex::new(Database::new(arguments.testing, site, settings.clone()).unwrap()),
|
||||
);
|
||||
}
|
||||
let port = arguments.port;
|
||||
let state = web::Data::new(AppState {
|
||||
databases,
|
||||
arguments,
|
||||
});
|
||||
let state = web::Data::new(AppState { databases });
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.service(get_comments)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue