This commit is contained in:
2025-02-18 16:26:30 +08:00
committed by Chris Tam
commit 4a76b8589a
7 changed files with 2434 additions and 0 deletions

228
src/main.rs Normal file
View File

@@ -0,0 +1,228 @@
use actix_cors::Cors;
use actix_web::{ http::header, web, App, HttpServer, Responder, HttpResponse };
use serde::{ Deserialize, Serialize };
use reqwest::Client as HttpClient;
use std::sync::Mutex;
use std::collections::HashMap;
use std::fs;
use std::env;
use std::io::Write;
use dotenv::dotenv;
use rand::Rng;
#[derive(Serialize, Deserialize, Debug, Clone)]
struct User {
id: u64,
username: String,
url: String,
random_bits: Vec<bool>
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct RegisterUser {
id: u64,
username: String,
url: String
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct GenerateKeyInstruction {
sender_id: u64,
receiver_id: u64,
hash: u64,
total: u64,
keysize: u64
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct Database {
users: HashMap<u64, User>
}
impl Database {
fn new() -> Self {
Self {
users: HashMap::new()
}
}
fn get_user(&self, id: &u64) -> Option<&User> {
self.users.get(id)
}
fn update_user(&mut self, user: User) {
self.users.insert(user.id, user);
}
fn insert_user(&mut self, user: User) {
self.users.insert(user.id, user);
}
fn save_to_file(&self) -> std::io::Result<()> {
let data: String = serde_json::to_string(&self)?;
let args: Vec<String> = env::args().collect();
let file_name = format!("database-{}.json", args.get(2).unwrap_or(&env::var("SECURITYHUB_ID").unwrap()));
let mut file: fs::File = fs::File::create(file_name)?;
file.write_all(data.as_bytes())?;
Ok(())
}
fn load_from_file() -> std::io::Result<Self> {
let args: Vec<String> = env::args().collect();
let file_name = format!("database-{}.json", args.get(2).unwrap_or(&env::var("SECURITYHUB_ID").unwrap()));
let file_content: String = fs::read_to_string(file_name)?;
let db: Database = serde_json::from_str(&file_content)?;
Ok(db)
}
}
struct AppState {
db: Mutex<Database>
}
async fn user_list(app_state: web::Data<AppState>) -> impl Responder {
let db: std::sync::MutexGuard<Database> = app_state.db.lock().unwrap();
let user_ids: Vec<u64> = db.users.keys().cloned().collect();
HttpResponse::Ok().json(user_ids)
}
async fn register_client(app_state: web::Data<AppState>, user: web::Json<RegisterUser>) -> impl Responder {
let mut db: std::sync::MutexGuard<Database> = app_state.db.lock().unwrap();
let mut random_bits: Vec<bool> = Vec::new();
let mut rng = rand::thread_rng();
// number of random bits to generate
let num_bits = env::var("GENERATE_RANDOM_BITS_NUM").unwrap().parse::<u64>().unwrap();
for _ in 0..num_bits {
let random_bit: u8 = rng.gen_range(0..=1);
random_bits.push(random_bit == 1);
}
let create_user = User {
id: user.id,
username: user.username.clone(),
url: user.url.clone(),
random_bits
};
db.insert_user(create_user.clone());
let _ = db.save_to_file();
HttpResponse::Ok().json(create_user)
}
async fn generate_key_instruction(app_state: web::Data<AppState>, generate_key_instruction: web::Json<GenerateKeyInstruction>) -> impl Responder {
let mut db: std::sync::MutexGuard<Database> = app_state.db.lock().unwrap();
let sender = db.get_user(&generate_key_instruction.sender_id);
match sender {
Some(_) => (),
None => return HttpResponse::NotFound().finish()
}
let mut sender = sender.unwrap().clone();
let receiver = db.get_user(&generate_key_instruction.receiver_id);
match receiver {
Some(_) => (),
None => return HttpResponse::NotFound().finish()
}
let mut receiver = receiver.unwrap().clone();
// check remaining random bits
if sender.random_bits.len() < generate_key_instruction.keysize as usize || receiver.random_bits.len() < generate_key_instruction.keysize as usize{
return HttpResponse::BadRequest().finish()
}
// calculate key instruction A ^ B
let key_instruction: Vec<bool> = (0..generate_key_instruction.keysize)
.map(|i| sender.random_bits[i as usize] ^ receiver.random_bits[i as usize])
.collect();
// reduce the sender random bits by keysize
sender.random_bits = sender.random_bits.iter().skip(generate_key_instruction.keysize as usize).cloned().collect();
db.update_user(sender.clone());
let client = HttpClient::new();
let args: Vec<String> = env::args().collect();
let securityhub_id_env = env::var("SECURITYHUB_ID").unwrap();
let securityhub_id: u64 = args.get(2).unwrap_or(&securityhub_id_env).parse().unwrap();
let res = client.post(&format!("{}/receive_key_instruction", receiver.url))
.json(&serde_json::json!({
"sender_id": sender.id,
"securityhub_id": securityhub_id,
"hash": generate_key_instruction.hash,
"total": generate_key_instruction.total,
"key_instruction": key_instruction.clone()
}))
.send()
.await;
let mut flag = false;
match res {
Ok(_) => {
println!("Key instruction sent to client {}", receiver.id);
receiver.random_bits = receiver.random_bits.iter().skip(generate_key_instruction.keysize as usize).cloned().collect();
db.update_user(receiver.clone());
flag = true;
},
Err(err) => {
println!("Error connecting to client {}: {:?}", receiver.id, err);
}
}
let _ = db.save_to_file();
if flag {
HttpResponse::Ok().json(key_instruction)
} else {
HttpResponse::NotFound().finish()
}
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
dotenv().ok();
let args: Vec<String> = env::args().collect();
let db: Database = match Database::load_from_file() {
Ok(db) => db,
Err(_) => Database::new()
};
let data: web::Data<AppState> = web::Data::new(AppState {
db: Mutex::new(db)
});
// if args[1] is set, use the port from the args
// else use the port from the .env
let port_env = env::var("PORT").unwrap();
let port = args.get(1).unwrap_or(&port_env);
println!("Starting server on port {}", port);
// if args[2] is set, use the securityhub id from the args
// else use the securityhub id from the .env
let securityhub_id_env = env::var("SECURITYHUB_ID").unwrap();
let securityhub_id = args.get(2).unwrap_or(&securityhub_id_env);
println!("Security Hub ID: {}", securityhub_id);
HttpServer::new(move || {
App::new()
.wrap(
Cors::permissive()
.allowed_origin_fn(|origin, _req_head| {
origin.as_bytes().starts_with(b"http://localhost") || origin == "null"
})
.allowed_methods(vec!["GET", "POST", "PUT", "DELETE"])
.allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
.allowed_header(header::CONTENT_TYPE)
.supports_credentials()
.max_age(3600)
)
.app_data(data.clone())
.route("/register_client", web::post().to(register_client))
.route("/generate_key_instruction", web::post().to(generate_key_instruction))
.route("/user_list", web::post().to(user_list))
})
.bind(format!("127.0.0.1:{}", port))?
.run()
.await
}