implement simple locking to be used for locking simultaneous updates of files

This commit is contained in:
Nico Fricke 2025-02-01 22:53:24 +01:00
parent c53bdb0ad3
commit 163d2a61ca
2 changed files with 126 additions and 2 deletions

View File

@ -0,0 +1,115 @@
use crate::locking::LockingResult::{LOCKED, SUCCESS};
use axum::http::StatusCode;
use problem_details::ProblemDetails;
use std::collections::HashSet;
use std::future::Future;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Clone)]
pub struct SimpleLock {
locks: Arc<Mutex<HashSet<String>>>,
}
impl SimpleLock {
pub fn new() -> Self {
Self {
locks: Arc::new(Mutex::new(HashSet::new())),
}
}
}
impl Lock for SimpleLock {
async fn lock(&self, id: &str) -> bool {
let mut locks = self.locks.lock().await;
if locks.contains(id) {
false
} else {
locks.insert(id.into());
true
}
}
async fn unlock(&self, id: &str) {
let mut locks = self.locks.lock().await;
locks.remove(id);
}
}
pub trait Lock {
async fn lock(&self, id: &str) -> bool;
async fn unlock(&self, id: &str);
async fn run_if_not_locked<T, RESULT>(
&self,
lock_id: &str,
func: fn() -> RESULT,
) -> LockingResult<T>
where
RESULT: Future<Output = T>,
{
if self.lock(lock_id).await {
let result = func().await;
self.unlock(lock_id).await;
SUCCESS(result)
} else {
LOCKED
}
}
async fn run_if_locked<T, RESULT>(
&self,
lock_id: &str,
func: fn() -> RESULT,
) -> Result<T, ProblemDetails>
where
RESULT: Future<Output = T>,
{
match self.run_if_not_locked(lock_id, func).await {
LOCKED => Err(ProblemDetails::from_status_code(StatusCode::CONFLICT)
.with_detail(format!("Locked {:?}", lock_id))),
SUCCESS(result) => Ok(result),
}
}
}
#[derive(PartialEq, Debug)]
enum LockingResult<T> {
LOCKED,
SUCCESS(T),
}
#[cfg(test)]
mod tests {
use crate::locking::{Lock, LockingResult, SimpleLock};
const LOCK_ID: &str = "1";
const RESULT: &str = "result";
#[tokio::test]
async fn test_locking() {
let lock = SimpleLock::new();
assert!(lock.lock(LOCK_ID).await);
assert!(!lock.lock(LOCK_ID).await);
lock.unlock(LOCK_ID).await;
assert!(lock.lock(LOCK_ID).await);
}
#[tokio::test]
async fn test_run_if_not_locked() {
let lock = SimpleLock::new();
assert_eq!(
LockingResult::SUCCESS(RESULT),
lock.run_if_not_locked(LOCK_ID, test_fn).await
);
assert!(lock.lock(LOCK_ID).await);
assert_eq!(
LockingResult::LOCKED,
lock.run_if_not_locked(LOCK_ID, test_fn).await
);
}
async fn test_fn() -> &'static str {
RESULT
}
}

View File

@ -6,11 +6,13 @@ mod db;
mod errors;
mod extractor_helper;
mod files;
mod locking;
mod routes;
use crate::config::Config;
use crate::db::repository::Repository;
use crate::db::sqlite::Sqlite;
use crate::locking::{Lock, SimpleLock};
use axum::{middleware, Router};
use axum_jwks::Jwks;
use std::env;
@ -20,9 +22,10 @@ use tracing::{debug, error, info};
#[derive(Clone)]
pub struct AppState {
config: config::Config,
config: Config,
pub jwks: Jwks,
pub sqlite: Sqlite,
sqlite: Sqlite,
pub locks: SimpleLock,
}
impl AppState {
@ -30,6 +33,11 @@ impl AppState {
pub fn get_repository(&self) -> &impl Repository {
&self.sqlite
}
#[must_use]
pub fn get_lock(&self) -> &impl Lock {
&self.locks
}
}
#[tokio::main]
@ -48,6 +56,7 @@ async fn main() {
config,
jwks,
sqlite: db,
locks: SimpleLock::new(),
};
let app = Router::new()
.merge(routes::routes())