From 90d0c6e3b6044a88f4679ef6daef89aef71d299a Mon Sep 17 00:00:00 2001 From: TuDatTr Date: Fri, 2 Feb 2024 03:39:52 +0100 Subject: [PATCH] Reworked MessageBody Typesystem and added challenge 03 Signed-off-by: TuDatTr --- src/lib.rs | 8 ++- src/messages.rs | 88 +++++++++++++------------- src/routes.rs | 92 ++++++++++++++++----------- tests/broadcast.rs | 95 ++++++++++++++++++++++++++++ tests/echo.rs | 92 +++++++++++++-------------- tests/uid.rs | 150 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 396 insertions(+), 129 deletions(-) create mode 100644 tests/broadcast.rs create mode 100644 tests/uid.rs diff --git a/src/lib.rs b/src/lib.rs index 1475e3d..71e3c24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,13 @@ use routes::challenge; pub mod messages; pub mod routes; -type AppState = Arc>; +type AppState = Arc>; + +#[derive(Default)] +pub struct State { + uid: IdCounter, + message: String, +} #[derive(Debug, Default)] pub struct IdCounter { diff --git a/src/messages.rs b/src/messages.rs index 04c709d..e9e17cd 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,11 +1,49 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] +#[serde(tag = "type")] pub enum MessageBody { - EchoRequest(EchoRequest), - EchoResponse(EchoResponse), - GenerateRequest(GenerateRequest), - GenerateResponse(GenerateResponse), + #[serde(rename = "echo")] + Echo { + msg_id: u64, + echo: String, + }, + #[serde(rename = "echo_ok")] + EchoOk { + msg_id: u64, + in_reply_to: u64, + echo: String, + }, + #[serde(rename = "generate")] + Generate { + msg_id: u64, + }, + #[serde(rename = "generate_ok")] + GenerateOk { + msg_id: u64, + in_reply_to: u64, + id: u64, + }, + #[serde(rename = "broadcast")] + Broadcast { + msg_id: u64, + message: String, + }, + #[serde(rename = "broadcast_ok")] + BroadcastOk { + msg_id: u64, + in_reply_to: u64, + }, + #[serde(rename = "read")] + Read { + msg_id: u64, + }, + #[serde(rename = "read_ok")] + ReadOk { + msg_id: u64, + in_reply_to: u64, + message: String, + }, Default, } @@ -24,45 +62,3 @@ impl Default for Message { Message { src, dest, body } } } - -#[derive(Debug, Serialize, Deserialize)] -pub struct EchoRequest { - #[serde(rename = "type")] - pub response_type: String, - pub msg_id: u64, - pub echo: String, -} - -impl EchoRequest { - pub fn name(&self) -> String { - "EchoRequest".to_string() - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct EchoResponse { - #[serde(rename = "type")] - pub response_type: String, - pub msg_id: u64, - pub in_reply_to: u64, - pub echo: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GenerateRequest { - #[serde(rename = "type")] - pub response_type: String, -} - -impl GenerateRequest { - pub fn name(&self) -> String { - "GenerateRequest".to_string() - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GenerateResponse { - #[serde(rename = "type")] - pub response_type: String, - pub id: u64, -} diff --git a/src/routes.rs b/src/routes.rs index 8812381..76e304b 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,8 +1,8 @@ use axum::{extract::State, http::StatusCode, Json}; -use tracing::error; +use tracing::{info}; use crate::{ - messages::{EchoResponse, GenerateResponse, Message, MessageBody}, + messages::{Message, MessageBody}, AppState, }; @@ -17,49 +17,69 @@ pub async fn challenge( let response: Message = { let src = payload.dest; let dest = payload.src; + info!("{:?}", payload.body); let body: MessageBody = match payload.body { - MessageBody::EchoRequest(r) => { - if r.response_type != "echo" { - error!( - "Invalid response_type {} for {}", - &r.response_type, - r.name() - ); - return (StatusCode::BAD_REQUEST, Json(Message::default())); - } + MessageBody::Echo { msg_id, echo } => { + let in_reply_to = msg_id; - let response_type = "echo_ok".to_string(); - let msg_id = r.msg_id; - let in_reply_to = r.msg_id; - let echo = r.echo; - - MessageBody::EchoResponse(EchoResponse { - response_type, + MessageBody::EchoOk { msg_id, in_reply_to, echo, - }) - } - MessageBody::GenerateRequest(r) => { - if r.response_type != "generate" { - error!( - "Invalid response_type {} for {}", - &r.response_type, - r.name() - ); - return (StatusCode::BAD_REQUEST, Json(Message::default())); } - let response_type = "generate_ok".to_string(); - let id = { state.lock().unwrap().next() }; - MessageBody::GenerateResponse(GenerateResponse { response_type, id }) } - MessageBody::EchoResponse(_r) => { - return (StatusCode::BAD_REQUEST, Json(Message::default())) + MessageBody::EchoOk { + msg_id: _, + in_reply_to: _, + echo: _, + } => return (StatusCode::BAD_REQUEST, Json(Message::default())), + MessageBody::Generate { msg_id } => { + let in_reply_to = msg_id; + let id = { state.lock().unwrap().uid.next() }; + MessageBody::GenerateOk { + msg_id, + in_reply_to, + id, + } } - MessageBody::GenerateResponse(_r) => { - return (StatusCode::BAD_REQUEST, Json(Message::default())) + MessageBody::GenerateOk { + msg_id: _, + in_reply_to: _, + id: _, + } => return (StatusCode::BAD_REQUEST, Json(Message::default())), + MessageBody::Broadcast { msg_id, message } => { + let in_reply_to = msg_id; + { + let mut local_state = state.lock().unwrap(); + local_state.message = message; + } + MessageBody::BroadcastOk { + msg_id, + in_reply_to, + } } - MessageBody::Default => return (StatusCode::BAD_REQUEST, Json(Message::default())), + MessageBody::BroadcastOk { + msg_id: _, + in_reply_to: _, + } => return (StatusCode::BAD_REQUEST, Json(Message::default())), + MessageBody::Read { msg_id } => { + let in_reply_to = msg_id; + let message = { + let local_state = state.lock().unwrap(); + local_state.message.clone() + }; + MessageBody::ReadOk { + msg_id, + in_reply_to, + message, + } + } + MessageBody::ReadOk { + msg_id: _, + in_reply_to: _, + message: _, + } => return (StatusCode::BAD_REQUEST, Json(Message::default())), + MessageBody::Default => todo!(), }; Message { src, dest, body } diff --git a/tests/broadcast.rs b/tests/broadcast.rs new file mode 100644 index 0000000..43bd5e5 --- /dev/null +++ b/tests/broadcast.rs @@ -0,0 +1,95 @@ +use axum::http::{self, Request, StatusCode}; +use echo::{ + app, + messages::{Message, MessageBody}, +}; +use http_body_util::BodyExt; +use serde_json::{json, Value}; +use tower::{Service, ServiceExt}; + +#[tokio::test] +async fn test_rw() { + let mut app = app().into_service(); + + let message = "test".to_string(); + + let response = { + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = 1; + let message = message.clone(); + + MessageBody::Broadcast { msg_id, message } + }; + Message { src, dest, body } + }; + + let request = Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(); + + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = 2; + + MessageBody::Read { msg_id } + }; + Message { src, dest, body } + }; + + let request = Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(); + + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + response + }; + + { + let body: Value = { + let body = response.into_body().collect().await.unwrap().to_bytes(); + serde_json::from_slice(&body).unwrap() + }; + + let expected = json!( + { + "src": "n1", + "dest": "c1", + "body": { + "type": "read_ok", + "msg_id": 2, + "in_reply_to": 2, + "message": message + } + } + ); + + assert_eq!(body, expected); + } +} diff --git a/tests/echo.rs b/tests/echo.rs index 5787303..901bfa2 100644 --- a/tests/echo.rs +++ b/tests/echo.rs @@ -1,61 +1,61 @@ use axum::http::{self, Request, StatusCode}; use echo::{ app, - messages::{EchoRequest, Message, MessageBody}, + messages::{Message, MessageBody}, }; use http_body_util::BodyExt; use serde_json::{json, Value}; use tower::ServiceExt; -use tracing::info; #[tokio::test] -async fn specification() { - tracing_subscriber::fmt::init(); - let request: Message = { - let src = "c1".to_string(); - let dest = "n1".to_string(); - let response_type = "echo".to_string(); - let msg_id = 1; - let echo = "Please echo 35".to_string(); - let body = MessageBody::EchoRequest(EchoRequest { - response_type, - msg_id, - echo, - }); - Message { src, dest, body } - }; - let body = serde_json::to_string_pretty(&request).unwrap(); - info!("Request: {body}"); - +async fn test_echo() { let app = app().into_service(); - let response = app - .oneshot( - Request::builder() - .method(http::Method::POST) - .uri("/echo") - .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(serde_json::to_string(&request).unwrap()) - .unwrap(), - ) - .await - .unwrap(); - assert_eq!(response.status(), StatusCode::OK); + let response = { + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = 1; + let echo = "Please echo 35".to_string(); - let body = response.into_body().collect().await.unwrap().to_bytes(); - let body: Value = serde_json::from_slice(&body).unwrap(); - let expected = json!( - { - "src": "n1", - "dest": "c1", - "body": { - "type": "echo_ok", - "msg_id": 1, - "in_reply_to": 1, - "echo": "Please echo 35" + MessageBody::Echo { msg_id, echo } + }; + Message { src, dest, body } + }; + + let response = app + .oneshot( + Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + response + }; + + { + let body = response.into_body().collect().await.unwrap().to_bytes(); + let body: Value = serde_json::from_slice(&body).unwrap(); + let expected = json!( + { + "src": "n1", + "dest": "c1", + "body": { + "type": "echo_ok", + "msg_id": 1, + "in_reply_to": 1, + "echo": "Please echo 35" + } } - } - ); + ); - assert_eq!(body, expected); + assert_eq!(body, expected); + } } diff --git a/tests/uid.rs b/tests/uid.rs new file mode 100644 index 0000000..27e2c45 --- /dev/null +++ b/tests/uid.rs @@ -0,0 +1,150 @@ +use axum::http::{self, Request, StatusCode}; +use echo::{ + app, + messages::{Message, MessageBody}, +}; +use http_body_util::BodyExt; +use serde_json::{json, Value}; +use tower::{Service, ServiceExt}; + +#[tokio::test] +async fn test_single() { + let app = app().into_service(); + + let response = { + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = 1; + + MessageBody::Generate { msg_id } + }; + Message { src, dest, body } + }; + + let response = app + .oneshot( + Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + response + }; + + { + let body: Value = { + let body = response.into_body().collect().await.unwrap().to_bytes(); + serde_json::from_slice(&body).unwrap() + }; + + let expected = json!( + { + "src": "n1", + "dest": "c1", + "body": { + "type": "generate_ok", + "msg_id": 1, + "in_reply_to": 1, + "id": 0 + } + } + ); + + assert_eq!(body, expected); + } +} + +#[tokio::test] +async fn test_multiple() { + let mut app = app().into_service(); + let request_count = 3; + + let response = { + for i in 0..request_count { + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = i; + + MessageBody::Generate { msg_id } + }; + Message { src, dest, body } + }; + + let request = Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(); + + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + } + + let request: Message = { + let src = "c1".to_string(); + let dest = "n1".to_string(); + let body = { + let msg_id = request_count; + + MessageBody::Generate { msg_id } + }; + Message { src, dest, body } + }; + + let request = Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(serde_json::to_string(&request).unwrap()) + .unwrap(); + + let response = ServiceExt::>::ready(&mut app) + .await + .unwrap() + .call(request) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + response + }; + + { + let body: Value = { + let body = response.into_body().collect().await.unwrap().to_bytes(); + serde_json::from_slice(&body).unwrap() + }; + + let expected = json!( + { + "src": "n1", + "dest": "c1", + "body": { + "type": "generate_ok", + "msg_id": request_count, + "in_reply_to": request_count, + "id": request_count + } + } + ); + + assert_eq!(body, expected); + } +}