From d41911f506abf3c6338c511147adacc64583eeb1 Mon Sep 17 00:00:00 2001 From: Arnaud Bailly Date: Tue, 15 Oct 2024 08:46:50 +0200 Subject: Keep tester thread handle and abort when unregistering --- rust/src/web.rs | 115 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 89 insertions(+), 26 deletions(-) diff --git a/rust/src/web.rs b/rust/src/web.rs index 8ee3075..3f8f056 100644 --- a/rust/src/web.rs +++ b/rust/src/web.rs @@ -1,15 +1,15 @@ use actix_web::{get, middleware::Logger, post, web, App, HttpResponse, HttpServer, Responder}; use chrono::{DateTime, Utc}; use clap::Parser; -use futures::lock::Mutex; use handlebars::{DirectorySourceOptions, Handlebars}; use log::info; use proptest::test_runner::{Config, RngAlgorithm, TestRng, TestRunner}; use rand::Rng; use serde::{Deserialize, Serialize}; +use std::sync::Mutex; use std::time::Duration; use std::{collections::HashMap, sync::Arc}; -use tokio::task; +use tokio::task::{self, JoinHandle}; use uuid::Uuid; use lambda::lambda::{eval_all, eval_whnf, generate_expr, generate_exprs, gensym, Environment}; @@ -53,6 +53,7 @@ struct Leaderboard { trait AppState: Send + Sync { fn register(&mut self, registration: &Registration) -> RegistrationResult; + fn unregister(&mut self, url: &String); } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] @@ -86,7 +87,7 @@ enum TestResult { } impl Client { - fn new(url: String, name: String) -> Self { + fn new(url: String, name: String, delay: Duration) -> Self { let id = Uuid::new_v4(); let runner = TestRunner::new_with_rng( Config::default(), @@ -99,7 +100,7 @@ impl Client { grade: 1, runner, results: Vec::new(), - delay: Duration::from_secs(10), + delay, } } @@ -182,15 +183,26 @@ impl Client { #[derive(Debug)] struct State { - clients: HashMap>>, + base_duration: Duration, + clients: HashMap>, JoinHandle<()>)>, } impl State { fn new() -> Self { + State::with_duration(Duration::from_secs(10)) + } + + fn with_duration(base_duration: Duration) -> Self { Self { + base_duration, clients: HashMap::new(), } } + + fn client_events(&self, url: &String) -> usize { + let client = self.clients.get(url).unwrap().0.lock().unwrap(); + client.results.len() + } } impl AppState for State { @@ -200,14 +212,21 @@ impl AppState for State { url: registration.url.clone(), } } else { - let client = Client::new(registration.url.clone(), registration.name.clone()); + let client = Client::new( + registration.url.clone(), + registration.name.clone(), + self.base_duration, + ); let id = client.id.to_string(); let client_ref = Arc::new(Mutex::new(client)); - let client_s = client_ref.clone(); - self.clients.insert(registration.url.clone(), client_ref); // let it run in the background // FIXME: should find a way to handle graceful termination - task::spawn(async move { send_tests(client_s).await }); + let client_handle = task::spawn(send_tests(client_ref.clone())); + + self.clients.insert( + registration.url.clone(), + (client_ref.clone(), client_handle), + ); RegistrationResult::RegistrationSuccess { id, @@ -215,6 +234,11 @@ impl AppState for State { } } } + + fn unregister(&mut self, url: &String) { + let (_, handle) = self.clients.get(url).unwrap(); + handle.abort() + } } #[post("/register")] @@ -222,7 +246,7 @@ async fn register( app_state: web::Data>>, registration: web::Json, ) -> impl Responder { - let result = app_state.lock().await.register(®istration); + let result = app_state.lock().unwrap().register(®istration); match result { RegistrationResult::RegistrationSuccess { .. } => HttpResponse::Ok().json(result), RegistrationResult::UrlAlreadyRegistered { .. } => HttpResponse::BadRequest().json(result), @@ -254,10 +278,10 @@ async fn leaderboard( app_state: web::Data>>, hb: web::Data>, ) -> impl Responder { - let clients = &app_state.lock().await.clients; + let clients = &app_state.lock().unwrap().clients; let mut client_data = vec![]; for client in clients.values() { - let client = client.lock().await; + let client = client.0.lock().unwrap(); client_data.push(ClientData::from(&client)); } client_data.sort_by(|a, b| b.grade.cmp(&a.grade)); @@ -321,31 +345,42 @@ async fn main() -> std::io::Result<()> { .await } +fn get_test(client_m: &Mutex) -> (String, String, String) { + let mut client = client_m.lock().unwrap(); + let (input, expected) = client.generate_expr(); + (input, client.url.clone(), expected) +} + async fn send_tests(client_m: Arc>) { loop { - let sleep = sleep_time(&client_m).await; + let sleep = sleep_time(&client_m); tokio::time::sleep(sleep).await; - { - let mut client = client_m.lock().await; - let (input, expected) = client.generate_expr(); - let response = send_test(&input, &client.url).await; + let (input, url, expected) = get_test(&client_m); - let test = client.check_result(&expected, &response); - client.apply(&test); + let response = send_test(&input, &url, sleep).await; + + apply_result(&client_m, expected, response); } } } -async fn sleep_time(client_m: &Arc>) -> Duration { - client_m.lock().await.time_to_next_test() +fn apply_result(client_m: &Mutex, expected: String, response: Result) { + let mut client = client_m.lock().unwrap(); + let test = client.check_result(&expected, &response); + client.apply(&test); +} + +fn sleep_time(client_m: &Arc>) -> Duration { + client_m.lock().unwrap().time_to_next_test() } -async fn send_test(input: &String, url: &String) -> Result { +async fn send_test(input: &String, url: &String, timeout: Duration) -> Result { info!("Sending {} to {}", input, url); let body = input.clone(); let response = reqwest::Client::new() .post(url) + .timeout(timeout) .header("content-type", "text/plain") .body(body) .send() @@ -426,7 +461,7 @@ mod app_tests { name: "foo".to_string(), }; - state.lock().await.register(®istration); + state.lock().unwrap().register(®istration); let req = test::TestRequest::post() .uri("/register") @@ -484,7 +519,7 @@ mod app_tests { #[actix_web::test] async fn get_leaderboard_returns_html_page_listing_clients_state() { let app_state = Arc::new(Mutex::new(State::new())); - app_state.lock().await.register(&Registration { + app_state.lock().unwrap().register(&Registration { url: "http://1.2.3.4".to_string(), name: "client1".to_string(), }); @@ -539,8 +574,37 @@ mod app_tests { ); } + #[test] + async fn unregistering_registered_client_stops_tester_thread_from_sending_tests() { + let mut app_state = State::with_duration(Duration::from_millis(100)); + let registration = Registration { + name: "foo".to_string(), + url: "http://1.2.3.4".to_string(), + }; + + let reg = app_state.register(®istration); + assert!(matches!( + reg, + RegistrationResult::RegistrationSuccess { .. } + )); + + tokio::time::sleep(Duration::from_millis(500)).await; + + app_state.unregister(®istration.url); + + let grade_before = app_state.client_events(®istration.url); + tokio::time::sleep(Duration::from_millis(500)).await; + let grade_after = app_state.client_events(®istration.url); + + assert_eq!(grade_before, grade_after); + } + fn client() -> Client { - Client::new("http://1.2.3.4".to_string(), "foo".to_string()) + Client::new( + "http://1.2.3.4".to_string(), + "foo".to_string(), + Duration::from_secs(10), + ) } #[test] @@ -575,7 +639,6 @@ mod app_tests { let parsed = parse(&input); match &parsed[..] { [Value::Sym(name)] => { - println!("{}", name); assert!(name.chars().all(|c| c.is_ascii_alphanumeric())); } _ => panic!("Expected symbol, got {:?}", parsed), -- cgit v1.2.3