From 2b4be3dab857e69cf715112d38ca30e8b5f4914b Mon Sep 17 00:00:00 2001 From: TuDatTr Date: Fri, 2 Feb 2024 01:30:40 +0100 Subject: [PATCH] Updated typing system Signed-off-by: TuDatTr --- src/lib.rs | 30 ++++++--- src/main.rs | 4 ++ src/messages.rs | 74 ++++++++++++++++------ src/routes.rs | 98 ++++++++++++++++++++++------- tests/{specification.rs => echo.rs} | 14 ++--- 5 files changed, 164 insertions(+), 56 deletions(-) rename tests/{specification.rs => echo.rs} (84%) diff --git a/src/lib.rs b/src/lib.rs index e0f425c..1475e3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,26 @@ -use axum::{ - routing::{get, post}, - Router, -}; -use routes::{echo, root}; +use std::sync::{Arc, Mutex}; + +use axum::{routing::post, Router}; +use routes::challenge; pub mod messages; pub mod routes; -pub fn app() -> Router { - Router::new() - .route("/", get(root)) - .route("/echo", post(echo)) +type AppState = Arc>; + +#[derive(Debug, Default)] +pub struct IdCounter { + value: u64, +} + +impl IdCounter { + fn next(&mut self) -> u64 { + let current = self.value; + self.value += 1; + current + } +} +pub fn app() -> Router { + let state = AppState::default(); + Router::new().route("/", post(challenge)).with_state(state) } diff --git a/src/main.rs b/src/main.rs index e1c2b55..614e140 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ + + use anyhow::Error; use echo::app; + use tracing::info; #[tokio::main] async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); + let binding = "0.0.0.0:3000"; info!("Creating Axum at: {binding}"); let listener = tokio::net::TcpListener::bind(&binding).await?; diff --git a/src/messages.rs b/src/messages.rs index af05b4b..04c709d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,30 +1,68 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] -pub struct EchoRequest { - pub src: String, - pub dest: String, - pub body: BodyRequest, +pub enum MessageBody { + EchoRequest(EchoRequest), + EchoResponse(EchoResponse), + GenerateRequest(GenerateRequest), + GenerateResponse(GenerateResponse), + Default, } #[derive(Debug, Deserialize, Serialize)] -pub struct BodyRequest { - pub r#type: String, - pub msg_id: u32, - pub echo: String, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct EchoResponse { +pub struct Message { pub src: String, pub dest: String, - pub body: BodyResponse, + pub body: MessageBody, } -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct BodyResponse { - pub r#type: String, - pub msg_id: u32, - pub in_reply_to: u32, +impl Default for Message { + fn default() -> Self { + let src = "".to_string(); + let dest = "".to_string(); + let body = MessageBody::Default; + Message { src, dest, body } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EchoRequest { + #[serde(rename = "type")] + pub response_type: String, + pub msg_id: u64, pub echo: String, } + +impl EchoRequest { + pub fn name(&self) -> String { + "EchoRequest".to_string() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EchoResponse { + #[serde(rename = "type")] + pub response_type: String, + pub msg_id: u64, + pub in_reply_to: u64, + pub echo: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateRequest { + #[serde(rename = "type")] + pub response_type: String, +} + +impl GenerateRequest { + pub fn name(&self) -> String { + "GenerateRequest".to_string() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateResponse { + #[serde(rename = "type")] + pub response_type: String, + pub id: u64, +} diff --git a/src/routes.rs b/src/routes.rs index 27681c9..8812381 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,35 +1,89 @@ -use axum::{http::StatusCode, Json}; -use tracing::{debug, error}; +use axum::{extract::State, http::StatusCode, Json}; +use tracing::error; -use crate::messages::{BodyResponse, EchoRequest, EchoResponse}; +use crate::{ + messages::{EchoResponse, GenerateResponse, Message, MessageBody}, + AppState, +}; pub async fn root() -> &'static str { "Hello, World!" } -pub async fn echo(Json(payload): Json) -> (StatusCode, Json) { - if payload.body.r#type != "echo" { - error!("Error."); - return (StatusCode::BAD_REQUEST, Json(EchoResponse::default())); - } - let response: EchoResponse = { +pub async fn challenge( + State(state): State, + Json(payload): Json, +) -> (StatusCode, Json) { + let response: Message = { let src = payload.dest; let dest = payload.src; - let body = { - let r#type = "echo_ok".to_string(); - let msg_id = payload.body.msg_id; - let in_reply_to = payload.body.msg_id; - let echo = payload.body.echo; - BodyResponse { - r#type, - msg_id, - in_reply_to, - echo, + let body: MessageBody = match payload.body { + MessageBody::EchoRequest(r) => { + if r.response_type != "echo" { + error!( + "Invalid response_type {} for {}", + &r.response_type, + r.name() + ); + return (StatusCode::BAD_REQUEST, Json(Message::default())); + } + + let response_type = "echo_ok".to_string(); + let msg_id = r.msg_id; + let in_reply_to = r.msg_id; + let echo = r.echo; + + MessageBody::EchoResponse(EchoResponse { + response_type, + msg_id, + in_reply_to, + echo, + }) } + MessageBody::GenerateRequest(r) => { + if r.response_type != "generate" { + error!( + "Invalid response_type {} for {}", + &r.response_type, + r.name() + ); + return (StatusCode::BAD_REQUEST, Json(Message::default())); + } + let response_type = "generate_ok".to_string(); + let id = { state.lock().unwrap().next() }; + MessageBody::GenerateResponse(GenerateResponse { response_type, id }) + } + MessageBody::EchoResponse(_r) => { + return (StatusCode::BAD_REQUEST, Json(Message::default())) + } + MessageBody::GenerateResponse(_r) => { + return (StatusCode::BAD_REQUEST, Json(Message::default())) + } + MessageBody::Default => return (StatusCode::BAD_REQUEST, Json(Message::default())), }; - EchoResponse { src, dest, body } + + Message { src, dest, body } }; - let json = serde_json::to_string_pretty(&response).unwrap(); - debug!("Response: {json}"); (StatusCode::OK, Json(response)) + + // let response: MessageResponse = { + // let src = payload.dest; + // let dest = payload.src; + // let body = { + // let r#type = "echo_ok".to_string(); + // let msg_id = payload.body.msg_id; + // let in_reply_to = payload.body.msg_id; + // let echo = payload.body.echo; + // BodyResponse { + // r#type, + // msg_id, + // in_reply_to, + // echo, + // } + // }; + // MessageResponse { src, dest, body } + // }; + // let json = serde_json::to_string_pretty(&response).unwrap(); + // debug!("Response: {json}"); + // (StatusCode::OK, Json(response)) } diff --git a/tests/specification.rs b/tests/echo.rs similarity index 84% rename from tests/specification.rs rename to tests/echo.rs index 6b5d130..5787303 100644 --- a/tests/specification.rs +++ b/tests/echo.rs @@ -1,7 +1,7 @@ use axum::http::{self, Request, StatusCode}; use echo::{ app, - messages::{BodyRequest, EchoRequest}, + messages::{EchoRequest, Message, MessageBody}, }; use http_body_util::BodyExt; use serde_json::{json, Value}; @@ -11,18 +11,18 @@ use tracing::info; #[tokio::test] async fn specification() { tracing_subscriber::fmt::init(); - let request = { + let request: Message = { let src = "c1".to_string(); let dest = "n1".to_string(); - let r#type = "echo".to_string(); + let response_type = "echo".to_string(); let msg_id = 1; let echo = "Please echo 35".to_string(); - let body = BodyRequest { - r#type, + let body = MessageBody::EchoRequest(EchoRequest { + response_type, msg_id, echo, - }; - EchoRequest { src, dest, body } + }); + Message { src, dest, body } }; let body = serde_json::to_string_pretty(&request).unwrap(); info!("Request: {body}");