|
12 | 12 | //! Then test: curl -H "Authorization: Bearer token123" http://127.0.0.1:8080/api/protected |
13 | 13 |
|
14 | 14 | use rustapi_rs::prelude::*; |
| 15 | +use std::future::Future; |
| 16 | +use std::pin::Pin; |
15 | 17 | use std::time::Instant; |
16 | | -use uuid::Uuid; |
| 18 | + |
| 19 | +// Import middleware traits from rustapi_core since they're not re-exported |
| 20 | +use rustapi_core::middleware::{BoxedNext, MiddlewareLayer}; |
17 | 21 |
|
18 | 22 | // ============================================ |
19 | 23 | // Custom Middleware |
20 | 24 | // ============================================ |
21 | 25 |
|
22 | 26 | /// Request ID Middleware - Adds unique ID to each request |
| 27 | +#[derive(Clone)] |
23 | 28 | struct RequestIdMiddleware; |
24 | 29 |
|
25 | 30 | impl RequestIdMiddleware { |
26 | 31 | fn new() -> Self { |
27 | 32 | Self |
28 | 33 | } |
| 34 | +} |
29 | 35 |
|
30 | | - async fn handle<B>(&self, req: Request<B>, next: Next<B>) -> Response { |
31 | | - let request_id = Uuid::new_v4().to_string(); |
32 | | - println!( |
33 | | - "📝 [{}] New request: {} {}", |
34 | | - request_id, |
35 | | - req.method(), |
36 | | - req.uri() |
37 | | - ); |
38 | | - |
39 | | - // Add request ID to headers |
40 | | - let mut response = next.run(req).await; |
41 | | - response |
42 | | - .headers_mut() |
43 | | - .insert("X-Request-ID", request_id.parse().unwrap()); |
44 | | - response |
| 36 | +impl MiddlewareLayer for RequestIdMiddleware { |
| 37 | + fn call( |
| 38 | + &self, |
| 39 | + req: Request, |
| 40 | + next: BoxedNext, |
| 41 | + ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> { |
| 42 | + Box::pin(async move { |
| 43 | + let request_id = generate_request_id(); |
| 44 | + println!( |
| 45 | + "📝 [{}] New request: {} {}", |
| 46 | + request_id, |
| 47 | + req.method(), |
| 48 | + req.uri() |
| 49 | + ); |
| 50 | + |
| 51 | + // Call next middleware/handler |
| 52 | + let mut response = next(req).await; |
| 53 | + |
| 54 | + // Add request ID to response headers |
| 55 | + if let Ok(header_value) = request_id.parse() { |
| 56 | + response.headers_mut().insert("X-Request-ID", header_value); |
| 57 | + } |
| 58 | + |
| 59 | + response |
| 60 | + }) |
| 61 | + } |
| 62 | + |
| 63 | + fn clone_box(&self) -> Box<dyn MiddlewareLayer> { |
| 64 | + Box::new(self.clone()) |
45 | 65 | } |
46 | 66 | } |
47 | 67 |
|
48 | 68 | /// Timing Middleware - Logs request duration |
| 69 | +#[derive(Clone)] |
49 | 70 | struct TimingMiddleware; |
50 | 71 |
|
51 | 72 | impl TimingMiddleware { |
52 | 73 | fn new() -> Self { |
53 | 74 | Self |
54 | 75 | } |
| 76 | +} |
55 | 77 |
|
56 | | - async fn handle<B>(&self, req: Request<B>, next: Next<B>) -> Response { |
57 | | - let start = Instant::now(); |
58 | | - let method = req.method().clone(); |
59 | | - let uri = req.uri().clone(); |
| 78 | +impl MiddlewareLayer for TimingMiddleware { |
| 79 | + fn call( |
| 80 | + &self, |
| 81 | + req: Request, |
| 82 | + next: BoxedNext, |
| 83 | + ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> { |
| 84 | + Box::pin(async move { |
| 85 | + let start = Instant::now(); |
| 86 | + let method = req.method().to_string(); |
| 87 | + let uri = req.uri().to_string(); |
60 | 88 |
|
61 | | - let response = next.run(req).await; |
| 89 | + let response = next(req).await; |
62 | 90 |
|
63 | | - let duration = start.elapsed(); |
64 | | - println!("⏱️ {} {} - {}ms", method, uri, duration.as_millis()); |
| 91 | + let duration = start.elapsed(); |
| 92 | + println!("⏱️ {} {} - {}ms", method, uri, duration.as_millis()); |
| 93 | + |
| 94 | + response |
| 95 | + }) |
| 96 | + } |
65 | 97 |
|
66 | | - response |
| 98 | + fn clone_box(&self) -> Box<dyn MiddlewareLayer> { |
| 99 | + Box::new(self.clone()) |
67 | 100 | } |
68 | 101 | } |
69 | 102 |
|
70 | 103 | /// Custom Auth Middleware - Simple token validation |
| 104 | +#[derive(Clone)] |
71 | 105 | struct CustomAuthMiddleware; |
72 | 106 |
|
73 | 107 | impl CustomAuthMiddleware { |
74 | 108 | fn new() -> Self { |
75 | 109 | Self |
76 | 110 | } |
| 111 | +} |
77 | 112 |
|
78 | | - async fn handle<B>(&self, req: Request<B>, next: Next<B>) -> Response { |
79 | | - // Check if route requires auth |
80 | | - let path = req.uri().path(); |
81 | | - if path.starts_with("/api/protected") { |
82 | | - // Validate auth header |
83 | | - if let Some(auth_header) = req.headers().get("Authorization") { |
84 | | - if let Ok(auth_str) = auth_header.to_str() { |
85 | | - if auth_str.starts_with("Bearer ") { |
86 | | - let token = &auth_str[7..]; |
87 | | - if token == "token123" { |
88 | | - println!("✅ Auth successful for {}", path); |
89 | | - return next.run(req).await; |
| 113 | +impl MiddlewareLayer for CustomAuthMiddleware { |
| 114 | + fn call( |
| 115 | + &self, |
| 116 | + req: Request, |
| 117 | + next: BoxedNext, |
| 118 | + ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> { |
| 119 | + Box::pin(async move { |
| 120 | + let path = req.uri().path(); |
| 121 | + |
| 122 | + // Check if route requires auth |
| 123 | + if path.starts_with("/api/protected") { |
| 124 | + // Validate auth header |
| 125 | + if let Some(auth_header) = req.headers().get("Authorization") { |
| 126 | + if let Ok(auth_str) = auth_header.to_str() { |
| 127 | + if auth_str.starts_with("Bearer ") { |
| 128 | + let token = &auth_str[7..]; |
| 129 | + if token == "token123" { |
| 130 | + println!("✅ Auth successful for {}", path); |
| 131 | + return next(req).await; |
| 132 | + } |
90 | 133 | } |
91 | 134 | } |
92 | 135 | } |
| 136 | + |
| 137 | + println!("❌ Auth failed for {}", path); |
| 138 | + // Return 401 Unauthorized |
| 139 | + use http::StatusCode; |
| 140 | + return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(); |
93 | 141 | } |
94 | 142 |
|
95 | | - println!("❌ Auth failed for {}", path); |
96 | | - return Response::builder() |
97 | | - .status(401) |
98 | | - .body("Unauthorized".into()) |
99 | | - .unwrap(); |
100 | | - } |
| 143 | + next(req).await |
| 144 | + }) |
| 145 | + } |
101 | 146 |
|
102 | | - next.run(req).await |
| 147 | + fn clone_box(&self) -> Box<dyn MiddlewareLayer> { |
| 148 | + Box::new(self.clone()) |
103 | 149 | } |
104 | 150 | } |
105 | 151 |
|
| 152 | +// ============================================ |
| 153 | +// Helper Functions |
| 154 | +// ============================================ |
| 155 | + |
| 156 | +/// Generate a simple request ID |
| 157 | +fn generate_request_id() -> String { |
| 158 | + use std::sync::atomic::{AtomicU64, Ordering}; |
| 159 | + static COUNTER: AtomicU64 = AtomicU64::new(0); |
| 160 | + let count = COUNTER.fetch_add(1, Ordering::Relaxed); |
| 161 | + let timestamp = std::time::SystemTime::now() |
| 162 | + .duration_since(std::time::UNIX_EPOCH) |
| 163 | + .unwrap() |
| 164 | + .as_millis(); |
| 165 | + format!("{:x}-{:x}", timestamp, count) |
| 166 | +} |
| 167 | + |
106 | 168 | // ============================================ |
107 | 169 | // Response Models |
108 | 170 | // ============================================ |
@@ -177,9 +239,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { |
177 | 239 |
|
178 | 240 | RustApi::auto() |
179 | 241 | // Middleware are executed in order |
180 | | - .middleware(RequestIdMiddleware::new()) |
181 | | - .middleware(TimingMiddleware::new()) |
182 | | - .middleware(CustomAuthMiddleware::new()) |
| 242 | + .layer(RequestIdMiddleware::new()) |
| 243 | + .layer(TimingMiddleware::new()) |
| 244 | + .layer(CustomAuthMiddleware::new()) |
183 | 245 | .run("127.0.0.1:8080") |
184 | 246 | .await |
185 | 247 | } |
0 commit comments