52 lines
1.4 KiB
Rust
52 lines
1.4 KiB
Rust
use axum::{
|
|
extract::{ConnectInfo, Request, State},
|
|
http::StatusCode,
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use redis::AsyncCommands;
|
|
use std::net::SocketAddr;
|
|
|
|
use crate::state::AppState;
|
|
|
|
pub async fn ip_rate_limit_middleware(
|
|
State(state): State<AppState>,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
req: Request,
|
|
next: Next,
|
|
) -> Result<Response, StatusCode> {
|
|
let ip = addr.ip().to_string();
|
|
let key = format!("ip_rate_limit:{}", ip);
|
|
let mut conn = state.redis.clone();
|
|
|
|
// Check if IP is blocked
|
|
let blocked: bool = conn.exists(format!("ip_blocked:{}", ip))
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if blocked {
|
|
return Err(StatusCode::FORBIDDEN);
|
|
}
|
|
|
|
// Aggressive rate limiting: 100 req/min per IP
|
|
let count: i64 = conn.incr(&key, 1)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
|
|
if count == 1 {
|
|
let _: () = conn.expire(&key, 60)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
}
|
|
|
|
if count > 100 {
|
|
// Block IP for 1 hour after exceeding limit
|
|
let _: () = conn.set_ex(format!("ip_blocked:{}", ip), "1", 3600)
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
|
return Err(StatusCode::TOO_MANY_REQUESTS);
|
|
}
|
|
|
|
Ok(next.run(req).await)
|
|
}
|