Initial commit: Full Crawl API implementation
Some checks failed
CI / Test (push) Has been cancelled
Deploy / Deploy to Staging (push) Has been cancelled
CI / Build & Push (push) Has been cancelled
Deploy / Deploy to Production (push) Has been cancelled

This commit is contained in:
2026-04-29 07:03:48 +00:00
commit 62994d4f3d
92 changed files with 6176 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
use db::connection::create_pool;
use db::repos::{api_keys, users};
use shared::config::AppConfig;
use std::sync::Arc;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;
/// Seeds the database with a demo user and a fresh demo API key, then prints
/// the credentials to stdout.
///
/// The demo user is reused when it already exists; a new API key is created on
/// every run.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer())
        .init();

    let config = Arc::new(AppConfig::from_env()?);
    let db = create_pool(&config.database_url).await?;

    // Demo account credentials.
    let email = "demo@crawlapi.dev";
    let password = "demo123456";
    let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;

    let existing = users::find_by_email(&db, email).await?;
    let user = if let Some(found) = existing {
        tracing::info!("User already exists: {}", found.id);
        found
    } else {
        let created = users::create(&db, email, Some(&password_hash), None).await?;
        tracing::info!("Created user: {} with 30 free credits", created.id);
        created
    };

    // Only the hash of the key is stored; the plaintext is shown once below.
    let key_suffix = Uuid::new_v4().to_string().replace('-', "");
    let api_key = format!("crawlapi_demo_{}", key_suffix);
    let key_hash = format!("{:x}", md5::compute(&api_key));
    let key = api_keys::create(&db, user.id, &key_hash, "Demo Key").await?;
    tracing::info!("Created API key: {} (id: {})", api_key, key.id);

    let rule = "========================================";
    println!("\n{}", rule);
    println!("SEED DATA CREATED");
    println!("{}", rule);
    println!("Email: {}", email);
    println!("Password: {}", password);
    println!("API Key: {}", api_key);
    println!("Credits: {}", user.credits);
    println!("{}\n", rule);

    Ok(())
}

8
crates/api/src/lib.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod middleware;
pub mod metrics;
pub mod queue;
pub mod routes;
pub mod secrets;
pub mod state;
pub mod storage;
pub mod validation;

79
crates/api/src/main.rs Normal file
View File

@@ -0,0 +1,79 @@
use std::sync::Arc;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
use api::{metrics, routes, state::AppState, storage};
use db::connection::create_pool;
use shared::config::AppConfig;
/// API server entry point.
///
/// Initializes Sentry (when `SENTRY_DSN` is set), logging, metrics, the
/// database pool (running migrations), Redis and S3, then serves the router on
/// `0.0.0.0:3000`.
///
/// # Errors
/// Returns an error if configuration, any backing service connection, the
/// migrations, the bucket check or the listener bind fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The guard must stay alive for the whole process so events get flushed.
    let sentry_dsn = std::env::var("SENTRY_DSN").ok();
    let _guard = sentry_dsn.map(|dsn| {
        sentry::init((
            dsn,
            sentry::ClientOptions {
                release: sentry::release_name!(),
                ..Default::default()
            },
        ))
    });

    // Structured JSON logging is opt-in via JSON_LOGGING=true.
    let json_logging = std::env::var("JSON_LOGGING")
        .map(|value| value == "true")
        .unwrap_or(false);
    let env_filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "api=debug,tower_http=debug".into());
    if json_logging {
        tracing_subscriber::registry()
            .with(env_filter)
            .with(
                tracing_subscriber::fmt::layer()
                    .json()
                    .with_current_span(true)
                    .with_span_list(true)
                    .with_target(true),
            )
            .init();
    } else {
        tracing_subscriber::registry()
            .with(env_filter)
            .with(tracing_subscriber::fmt::layer())
            .init();
    }

    metrics::register_metrics();

    let config = Arc::new(AppConfig::from_env()?);
    let db = create_pool(&config.database_url).await?;
    sqlx::migrate!("../db/migrations").run(&db).await?;

    let redis = redis::Client::open(config.redis_url.clone())?;
    let redis_conn = redis.get_multiplexed_tokio_connection().await?;

    let s3_config = aws_config::from_env()
        .endpoint_url(&config.s3_endpoint)
        .load()
        .await;
    let s3 = aws_sdk_s3::Client::new(&s3_config);
    storage::ensure_bucket_exists(&s3, &config.s3_bucket).await?;

    let state = AppState {
        config,
        db,
        redis: redis_conn,
        s3,
    };
    let app = routes::create_router(state).layer(TraceLayer::new_for_http());

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    tracing::info!("API server listening on {}", listener.local_addr()?);

    // Middleware in this crate extracts `ConnectInfo<SocketAddr>`; axum only
    // provides it when the service is built with connect info, otherwise the
    // extractor is rejected on every request.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
    )
    .await?;
    Ok(())
}

31
crates/api/src/metrics.rs Normal file
View File

@@ -0,0 +1,31 @@
use lazy_static::lazy_static;
use prometheus::{CounterVec, HistogramVec, Registry};
use std::time::Instant;
// Process-wide Prometheus metrics, initialized lazily on first access.
lazy_static! {
// Registry the collectors below are added to by `register_metrics`; it is
// what the `/metrics` route gathers from.
pub static ref REGISTRY: Registry = Registry::new();
// Request counter, labelled by `endpoint` and `status`.
pub static ref REQUEST_COUNTER: CounterVec = CounterVec::new(
prometheus::Opts::new("api_requests_total", "Total API requests"),
&["endpoint", "status"]
)
.unwrap();
// Request duration histogram in seconds, labelled by `endpoint`.
pub static ref REQUEST_DURATION: HistogramVec = HistogramVec::new(
prometheus::HistogramOpts::new("api_request_duration_seconds", "Request duration in seconds"),
&["endpoint"]
)
.unwrap();
}
/// Adds the request counter and the duration histogram to [`REGISTRY`].
///
/// # Panics
/// Panics if either collector cannot be registered (for example when this
/// function is called twice).
pub fn register_metrics() {
    let request_counter = Box::new(REQUEST_COUNTER.clone());
    REGISTRY.register(request_counter).unwrap();

    let request_duration = Box::new(REQUEST_DURATION.clone());
    REGISTRY.register(request_duration).unwrap();
}
/// Increments the request counter for the given `endpoint` / `status` pair.
pub fn record_request(endpoint: &str, status: &str) {
    let labels = [endpoint, status];
    let counter = REQUEST_COUNTER.with_label_values(&labels);
    counter.inc();
}
/// Observes the time elapsed since `start`, in seconds, on the duration
/// histogram for `endpoint`.
pub fn record_duration(endpoint: &str, start: Instant) {
    let elapsed_secs = start.elapsed().as_secs_f64();
    let labels = [endpoint];
    REQUEST_DURATION
        .with_label_values(&labels)
        .observe(elapsed_secs);
}

View File

@@ -0,0 +1,52 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use db::repos::api_keys;
use shared::models::User;
use crate::state::AppState;
/// Authentication context that `api_key_middleware` stores in the request
/// extensions once an API key has been accepted.
#[derive(Clone)]
pub struct ApiKeyAuth {
/// User that owns the presented API key.
pub user: User,
/// Identifier of the API key record that matched the request.
pub api_key_id: uuid::Uuid,
}
/// Authenticates a request by its `x-api-key` header.
///
/// The key is hashed, looked up, and its owner loaded; on success the key's
/// last-used marker is updated and an [`ApiKeyAuth`] is inserted into the
/// request extensions before the request is forwarded.
///
/// # Errors
/// * `401 Unauthorized` – header missing or not valid text, unknown key, or
///   the owning user no longer exists.
/// * `500 Internal Server Error` – any repository call failed.
pub async fn api_key_middleware(
    State(state): State<AppState>,
    mut req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let Some(presented_key) = req
        .headers()
        .get("x-api-key")
        .and_then(|value| value.to_str().ok())
    else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    // Keys are stored as a hex digest, never in plaintext.
    let key_hash = format!("{:x}", md5::compute(presented_key));

    let record = match api_keys::find_by_key_hash(&state.db, &key_hash).await {
        Ok(Some(record)) => record,
        Ok(None) => return Err(StatusCode::UNAUTHORIZED),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    let owner = match db::repos::users::find_by_id(&state.db, record.user_id).await {
        Ok(Some(owner)) => owner,
        Ok(None) => return Err(StatusCode::UNAUTHORIZED),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    if api_keys::update_last_used(&state.db, record.id).await.is_err() {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    req.extensions_mut().insert(ApiKeyAuth {
        user: owner,
        api_key_id: record.id,
    });

    Ok(next.run(req).await)
}

View File

@@ -0,0 +1,42 @@
use axum::{
extract::{Request, State},
http::{header::HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use tracing::{info_span, Instrument};
use uuid::Uuid;
use crate::state::AppState;
pub async fn correlation_id_middleware(
State(_state): State<AppState>,
mut req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let correlation_id = req
.headers()
.get("x-correlation-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| Uuid::new_v4().to_string());
req.headers_mut().insert(
"x-correlation-id",
HeaderValue::from_str(&correlation_id).unwrap(),
);
let method = req.method().to_string();
let uri = req.uri().to_string();
let span = info_span!(
"http_request",
correlation_id = %correlation_id,
method = %method,
uri = %uri,
);
let response = next.run(req).instrument(span).await;
Ok(response)
}

View File

@@ -0,0 +1,49 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::state::AppState;
/// Claims carried by the JWTs this service decodes in `jwt_middleware`.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtClaims {
/// Subject; `jwt_middleware` parses it as the user's UUID.
pub sub: String,
/// Expiry claim (`exp`).
pub exp: usize,
}
/// Authentication context that `jwt_middleware` stores in the request
/// extensions once a token has been accepted.
#[derive(Clone)]
pub struct JwtAuth {
/// User ID parsed from the token's `sub` claim.
pub user_id: Uuid,
}
/// Authenticates a request by a JWT.
///
/// The token is read from `x-auth-token`, falling back to `authorization`
/// when the former is absent; a leading `"Bearer "` is stripped if present.
/// On success a [`JwtAuth`] with the user ID from the `sub` claim is inserted
/// into the request extensions.
///
/// # Errors
/// Returns `401 Unauthorized` when no usable header is present, the token
/// fails to decode or validate, or `sub` is not a UUID.
pub async fn jwt_middleware(
    State(state): State<AppState>,
    mut req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let headers = req.headers();
    let raw_header = headers
        .get("x-auth-token")
        .or_else(|| headers.get("authorization"));
    let header_text = match raw_header.and_then(|value| value.to_str().ok()) {
        Some(text) => text,
        None => return Err(StatusCode::UNAUTHORIZED),
    };
    let token = header_text.strip_prefix("Bearer ").unwrap_or(header_text);

    let decoding_key = DecodingKey::from_secret(state.config.jwt_secret.as_bytes());
    let validation = Validation::default();
    let claims = match decode::<JwtClaims>(token, &decoding_key, &validation) {
        Ok(data) => data.claims,
        Err(_) => return Err(StatusCode::UNAUTHORIZED),
    };

    let user_id = match Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(_) => return Err(StatusCode::UNAUTHORIZED),
    };

    req.extensions_mut().insert(JwtAuth { user_id });
    Ok(next.run(req).await)
}

View File

@@ -0,0 +1,5 @@
pub mod auth;
pub mod correlation;
pub mod jwt;
pub mod rate_limit;
pub mod waf;

View File

@@ -0,0 +1,36 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use redis::AsyncCommands;
use crate::state::AppState;
pub async fn rate_limit_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let api_key = req
.headers()
.get("x-api-key")
.and_then(|v| v.to_str().ok())
.unwrap_or("anonymous");
let key = format!("rate_limit:{}", api_key);
let mut conn = state.redis.clone();
let count: i64 = conn.incr(&key, 1).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if count == 1 {
let _: () = conn.expire(&key, 60).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
if count > 60 {
return Err(StatusCode::TOO_MANY_REQUESTS);
}
Ok(next.run(req).await)
}

View File

@@ -0,0 +1,51 @@
use axum::{
extract::{ConnectInfo, Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use redis::AsyncCommands;
use std::net::SocketAddr;
use crate::state::AppState;
pub async fn ip_rate_limit_middleware(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let ip = addr.ip().to_string();
let key = format!("ip_rate_limit:{}", ip);
let mut conn = state.redis.clone();
// Check if IP is blocked
let blocked: bool = conn.exists(format!("ip_blocked:{}", ip))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if blocked {
return Err(StatusCode::FORBIDDEN);
}
// Aggressive rate limiting: 100 req/min per IP
let count: i64 = conn.incr(&key, 1)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if count == 1 {
let _: () = conn.expire(&key, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
if count > 100 {
// Block IP for 1 hour after exceeding limit
let _: () = conn.set_ex(format!("ip_blocked:{}", ip), "1", 3600)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Err(StatusCode::TOO_MANY_REQUESTS);
}
Ok(next.run(req).await)
}

94
crates/api/src/queue.rs Normal file
View File

@@ -0,0 +1,94 @@
use redis::AsyncCommands;
use shared::{
models::CrawlOptions,
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
};
use std::time::Duration;
use tokio::time::sleep;
use uuid::Uuid;
use crate::state::AppState;
/// Builds a [`Job`] with a fresh ID, serializes it to JSON and appends it to
/// the Redis list [`QUEUE_NAME`].
///
/// Returns the ID of the enqueued job.
///
/// # Errors
/// Returns the Redis error if the push fails.
///
/// # Panics
/// Panics if the job cannot be serialized to JSON.
pub async fn enqueue_job(
    state: &AppState,
    user_id: Uuid,
    api_key_id: Uuid,
    endpoint: &str,
    url: &str,
    options: &CrawlOptions,
    webhook_url: Option<String>,
) -> Result<Uuid, redis::RedisError> {
    let id = Uuid::new_v4();
    let payload = serde_json::to_string(&Job {
        id,
        user_id,
        api_key_id,
        endpoint: endpoint.to_string(),
        url: url.to_string(),
        options: options.clone(),
        webhook_url,
    })
    .unwrap();

    let mut conn = state.redis.clone();
    let _: () = conn.rpush(QUEUE_NAME, payload).await?;
    Ok(id)
}
/// Polls Redis every 200 ms for the result of `job_id` until it appears or
/// `timeout_secs` have elapsed.
///
/// Returns `Ok(None)` on timeout. A stored result that cannot be deserialized
/// is reported as a failed [`JobResult`] rather than as an error.
///
/// # Errors
/// Returns the Redis error if a read fails.
pub async fn wait_for_result(
    state: &AppState,
    job_id: Uuid,
    timeout_secs: u64,
) -> Result<Option<JobResult>, redis::RedisError> {
    let result_key = format!("{}{}", RESULT_PREFIX, job_id);
    let mut conn = state.redis.clone();
    let budget = Duration::from_secs(timeout_secs);
    let started = std::time::Instant::now();

    loop {
        if started.elapsed() >= budget {
            return Ok(None);
        }

        let payload: Option<String> = conn.get(&result_key).await?;
        match payload {
            Some(json) => {
                let parsed = match serde_json::from_str::<JobResult>(&json) {
                    Ok(result) => result,
                    Err(_) => JobResult {
                        id: job_id,
                        success: false,
                        data: None,
                        error: Some("Failed to deserialize result".to_string()),
                        duration_ms: 0,
                    },
                };
                return Ok(Some(parsed));
            }
            None => sleep(Duration::from_millis(200)).await,
        }
    }
}
/// Derives the Redis cache key for a request from its URL, endpoint and
/// JSON-serialized options.
///
/// Options that fail to serialize contribute an empty string to the digest.
pub async fn get_cache_key(url: &str, endpoint: &str, options: &CrawlOptions) -> String {
    let serialized_options = match serde_json::to_string(options) {
        Ok(json) => json,
        Err(_) => String::new(),
    };
    let fingerprint = format!("{}:{}:{}", url, endpoint, serialized_options);
    let digest = md5::compute(fingerprint);
    format!("crawlapi:cache:{:x}", digest)
}
pub async fn get_cached_result(
state: &AppState,
cache_key: &str,
) -> Result<Option<JobResult>, redis::RedisError> {
let mut conn = state.redis.clone();
let result_json: Option<String> = conn.get::<_, Option<String>>(cache_key).await?;
if let Some(json) = result_json {
let result: JobResult = serde_json::from_str(&json).unwrap();
return Ok(Some(result));
}
Ok(None)
}
/// Stores `result` as JSON under `cache_key` with an expiry of `ttl_secs`
/// seconds.
///
/// # Errors
/// Returns the Redis error if the write fails.
///
/// # Panics
/// Panics if `result` cannot be serialized to JSON.
pub async fn set_cached_result(
    state: &AppState,
    cache_key: &str,
    result: &JobResult,
    ttl_secs: u64,
) -> Result<(), redis::RedisError> {
    let payload = serde_json::to_string(result).unwrap();
    let mut conn = state.redis.clone();
    let _: () = conn.set_ex(cache_key, payload, ttl_secs).await?;
    Ok(())
}

130
crates/api/src/routes/ai.rs Normal file
View File

@@ -0,0 +1,130 @@
use axum::{
extract::{Json, State},
http::StatusCode,
Extension,
};
use serde::{Deserialize, Serialize};
use shared::models::CrawlRequest;
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState};
/// Request body of the AI extraction endpoint.
#[derive(Debug, Deserialize)]
pub struct AiExtractRequest {
/// Page to crawl and extract data from.
pub url: String,
/// JSON schema describing the shape of the data to extract.
pub schema: serde_json::Value,
// NOTE(review): `skip_serializing_if` only affects serialization, and this
// struct derives `Deserialize` only, so the attribute appears to have no effect.
#[serde(skip_serializing_if = "Option::is_none")]
/// Optional caller-supplied instruction for the model.
pub prompt: Option<String>,
}
/// Response body of the AI extraction endpoint.
#[derive(Debug, Serialize)]
pub struct AiExtractResponse {
/// Whether extraction succeeded.
pub success: bool,
/// Extracted data when `success` is true.
pub data: Option<serde_json::Value>,
/// Human-readable failure reason when `success` is false.
pub error: Option<String>,
}
pub async fn extract(
State(state): State<AppState>,
Extension(auth): Extension<ApiKeyAuth>,
Json(body): Json<AiExtractRequest>,
) -> Result<Json<AiExtractResponse>, StatusCode> {
let openai_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
if openai_key.is_empty() {
return Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("OpenAI not configured".to_string()),
}));
}
// Crawl the page via queue
let crawl_req = CrawlRequest {
url: body.url.clone(),
options: Default::default(),
};
let job_id = queue::enqueue_job(
&state,
auth.user.id,
auth.api_key_id,
"crawl",
&body.url,
&crawl_req.options,
None,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let job_result = queue::wait_for_result(&state, job_id, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
match job_result {
Some(result) if result.success => {
let data = result.data.unwrap_or_default();
let html = data.get("html").and_then(|v| v.as_str()).unwrap_or("");
let title = data.get("title").and_then(|v| v.as_str()).unwrap_or("");
// Call OpenAI
let client = reqwest::Client::new();
let system_prompt = format!(
"You are a web scraping assistant. Extract structured data from the following HTML page titled '{}'. \
Return ONLY a JSON object matching the requested schema. Do not include any explanation.",
title
);
let user_prompt = if let Some(p) = body.prompt {
p
} else {
format!("Extract data from this HTML according to schema: {}\n\nHTML:\n{}",
body.schema.to_string(),
&html[..html.len().min(8000)])
};
let res = client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {}", openai_key))
.json(&serde_json::json!({
"model": "gpt-4o-mini",
"messages": [
{ "role": "system", "content": system_prompt },
{ "role": "user", "content": user_prompt }
],
"temperature": 0.1,
"response_format": { "type": "json_object" }
}))
.send()
.await;
match res {
Ok(response) => {
let ai_data: serde_json::Value = response.json().await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(content) = ai_data["choices"][0]["message"]["content"].as_str() {
let parsed: serde_json::Value = serde_json::from_str(content).unwrap_or_else(|_| serde_json::json!({"raw": content}));
Ok(Json(AiExtractResponse {
success: true,
data: Some(parsed),
error: None,
}))
} else {
Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("Invalid OpenAI response".to_string()),
}))
}
}
Err(e) => Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some(format!("OpenAI error: {}", e)),
})),
}
}
_ => Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("Failed to crawl page".to_string()),
})),
}
}

View File

@@ -0,0 +1,142 @@
use axum::{
extract::{Json, Path, State},
http::StatusCode,
Extension,
};
use db::repos::{api_keys, users};
use serde::{Deserialize, Serialize};
use shared::models::{ApiKey, User};
use uuid::Uuid;
use crate::{middleware::jwt::JwtAuth, state::AppState};
/// Request body for account registration.
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
/// Email address of the new account.
pub email: String,
/// Plaintext password; hashed with bcrypt before it is stored.
pub password: String,
}
/// Response returned after a successful registration or login.
// NOTE(review): `User` exposes a `password_hash` field elsewhere in this file;
// confirm its `Serialize` impl skips that field so the hash is not returned.
#[derive(Debug, Serialize)]
pub struct AuthResponse {
/// The authenticated user.
pub user: User,
/// JWT issued for the user.
pub token: String,
}
pub async fn register(
State(state): State<AppState>,
Json(body): Json<RegisterRequest>,
) -> Result<Json<AuthResponse>, StatusCode> {
if users::find_by_email(&state.db, &body.email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?.is_some() {
return Err(StatusCode::CONFLICT);
}
let password_hash = bcrypt::hash(&body.password, bcrypt::DEFAULT_COST).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let user = users::create(&state.db, &body.email, Some(&password_hash), None)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
Ok(Json(AuthResponse { user, token }))
}
/// Request body for password login.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
/// Email address of the account.
pub email: String,
/// Plaintext password to verify against the stored bcrypt hash.
pub password: String,
}
pub async fn login(
State(state): State<AppState>,
Json(body): Json<LoginRequest>,
) -> Result<Json<AuthResponse>, StatusCode> {
let user = users::find_by_email(&state.db, &body.email)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
let password_hash = user.password_hash.as_ref().ok_or(StatusCode::UNAUTHORIZED)?;
if !bcrypt::verify(&body.password, password_hash).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
return Err(StatusCode::UNAUTHORIZED);
}
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
Ok(Json(AuthResponse { user, token }))
}
pub fn create_jwt(user_id: &str, secret: &str) -> Result<String, StatusCode> {
use jsonwebtoken::{encode, EncodingKey, Header};
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String,
exp: usize,
}
let claims = Claims {
sub: user_id.to_string(),
exp: (chrono::Utc::now() + chrono::Duration::days(30)).timestamp() as usize,
};
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_bytes()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Request body for creating an API key.
#[derive(Debug, Deserialize)]
pub struct CreateApiKeyRequest {
/// Display name for the new key.
pub name: String,
}
/// Response returned when an API key is created.
#[derive(Debug, Serialize)]
pub struct ApiKeyResponse {
/// Identifier of the stored key record.
pub id: Uuid,
/// Plaintext key; only its hash is passed to the repository.
pub key: String,
/// Display name of the key.
pub name: String,
}
/// Creates a new API key for the authenticated user.
///
/// Only a hex digest of the key is persisted; the plaintext key is returned
/// in the response.
///
/// # Errors
/// Returns `500 Internal Server Error` if the key cannot be stored.
pub async fn create_api_key(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
    Json(body): Json<CreateApiKeyRequest>,
) -> Result<Json<ApiKeyResponse>, StatusCode> {
    let raw_key = format!("crawlapi_{}", Uuid::new_v4().to_string().replace('-', ""));
    let digest = md5::compute(&raw_key);
    let key_hash = format!("{:x}", digest);

    let record = match api_keys::create(&state.db, auth.user_id, &key_hash, &body.name).await {
        Ok(record) => record,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    let response = ApiKeyResponse {
        id: record.id,
        key: raw_key,
        name: record.name,
    };
    Ok(Json(response))
}
/// Lists the API keys that belong to the authenticated user.
///
/// # Errors
/// Returns `500 Internal Server Error` if the lookup fails.
pub async fn list_api_keys(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
) -> Result<Json<Vec<ApiKey>>, StatusCode> {
    match api_keys::list_by_user(&state.db, auth.user_id).await {
        Ok(keys) => Ok(Json(keys)),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Deletes the API key `id` on behalf of the authenticated user.
///
/// # Errors
/// * `404 Not Found` – the repository reported that nothing was deleted.
/// * `500 Internal Server Error` – the delete call failed.
pub async fn delete_api_key(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, StatusCode> {
    match api_keys::delete_by_id(&state.db, id, auth.user_id).await {
        Ok(true) => Ok(StatusCode::NO_CONTENT),
        Ok(false) => Err(StatusCode::NOT_FOUND),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

View File

@@ -0,0 +1,296 @@
use axum::{
extract::{Json, State},
http::StatusCode,
Extension,
};
use db::repos::{usage_logs, users};
use serde_json::json;
use shared::{
error::AppError,
models::{CrawlRequest, CrawlResponse},
};
use std::time::Instant;
use tokio::fs;
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState, storage, validation};
/// Replaces a local `file_path` entry in a job result with a public `url`.
///
/// When `result` contains a string `file_path`, the file is read, uploaded to
/// object storage, and the entry is replaced by the uploaded object's URL.
/// Results without a `file_path` are left untouched.
///
/// The `file_path` entry is removed from `result` and the temporary file is
/// deleted on every path, including failures, so a server-local path never
/// reaches the client and temp files do not accumulate when uploads fail.
///
/// # Errors
/// Returns an [`AppError`] if the file cannot be read or the upload fails.
async fn upload_files_if_needed(
    state: &AppState,
    endpoint: &str,
    result: &mut serde_json::Value,
) -> Result<(), AppError> {
    let Some(file_path) = result
        .get("file_path")
        .and_then(|v| v.as_str())
        .map(String::from)
    else {
        return Ok(());
    };
    if let Some(obj) = result.as_object_mut() {
        obj.remove("file_path");
    }

    // Once the bytes are in memory (or the read has failed) the temp file is
    // no longer needed; removal is best-effort.
    let read_outcome = fs::read(&file_path).await;
    let _ = fs::remove_file(&file_path).await;
    let file_data =
        read_outcome.map_err(|e| AppError::Internal(format!("Failed to read file: {}", e)))?;

    let (ext, content_type) = if endpoint == "pdf" {
        ("pdf", "application/pdf")
    } else {
        ("png", "image/png")
    };
    let key = storage::generate_file_key(endpoint, ext);
    storage::upload_file(
        &state.s3,
        &state.config.s3_bucket,
        &key,
        content_type,
        file_data,
    )
    .await?;

    let public_url = storage::get_public_url(
        endpoint,
        &state.config.s3_endpoint,
        &state.config.s3_bucket,
        &key,
    );
    if let Some(obj) = result.as_object_mut() {
        obj.insert("url".to_string(), json!(public_url));
    }
    Ok(())
}
async fn handle_endpoint(
state: State<AppState>,
Extension(auth): Extension<ApiKeyAuth>,
Json(body): Json<CrawlRequest>,
endpoint: &'static str,
) -> Result<Json<CrawlResponse>, StatusCode> {
let start = Instant::now();
// Validate URL
if let Err(e) = validation::validate_url(&body.url) {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"error",
0,
start.elapsed().as_millis() as i64,
)
.await;
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(auth.user.credits),
error: Some(e.to_string()),
}));
}
// Validate webhook URL if provided
if let Some(ref webhook_url) = body.options.webhook_url {
if let Err(e) = validation::validate_webhook_url(webhook_url) {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(auth.user.credits),
error: Some(e.to_string()),
}));
}
}
// Check credits
if auth.user.credits <= 0 {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(0),
error: Some("Insufficient credits".to_string()),
}));
}
// Deduct credits
let has_credits = users::deduct_credits(&state.db, auth.user.id, 1)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !has_credits {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(0),
error: Some("Insufficient credits".to_string()),
}));
}
// Try cache first
let cache_key = queue::get_cache_key(&body.url, endpoint, &body.options).await;
let cached = queue::get_cached_result(&state, &cache_key)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(cached_result) = cached {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"success",
1,
start.elapsed().as_millis() as i64,
)
.await;
return Ok(Json(CrawlResponse {
success: cached_result.success,
data: cached_result.data,
calls_remaining: Some(auth.user.credits - 1),
error: cached_result.error,
}));
}
// Enqueue job and wait for result
let job_id = queue::enqueue_job(
&state,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
&body.options,
body.options.webhook_url.clone(),
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let job_result = queue::wait_for_result(&state, job_id, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let duration = start.elapsed().as_millis() as i64;
let remaining = auth.user.credits - 1;
match job_result {
Some(mut result) => {
// Upload files if needed
if let Some(ref mut data) = result.data {
let _ = upload_files_if_needed(&state, endpoint, data).await;
}
// Cache successful results for 5 minutes
if result.success {
let _ = queue::set_cached_result(&state, &cache_key, &result, 300).await;
}
let status = if result.success { "success" } else { "error" };
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
status,
1,
duration,
)
.await;
Ok(Json(CrawlResponse {
success: result.success,
data: result.data,
calls_remaining: Some(remaining),
error: result.error,
}))
}
None => {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"timeout",
1,
duration,
)
.await;
Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(remaining),
error: Some("Job timed out".to_string()),
}))
}
}
}
/// Handler that delegates to `handle_endpoint` with the `"crawl"` endpoint name.
pub async fn handle_crawl(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "crawl").await
}
/// Handler that delegates to `handle_endpoint` with the `"content"` endpoint name.
pub async fn handle_content(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "content").await
}
/// Handler that delegates to `handle_endpoint` with the `"screenshot"` endpoint
/// name; file outputs of this endpoint are uploaded as PNG.
pub async fn handle_screenshot(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "screenshot").await
}
/// Handler that delegates to `handle_endpoint` with the `"pdf"` endpoint name;
/// file outputs of this endpoint are uploaded as PDF.
pub async fn handle_pdf(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "pdf").await
}
/// Handler that delegates to `handle_endpoint` with the `"markdown"` endpoint name.
pub async fn handle_markdown(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "markdown").await
}
/// Handler that delegates to `handle_endpoint` with the `"snapshot"` endpoint name.
pub async fn handle_snapshot(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "snapshot").await
}
/// Handler that delegates to `handle_endpoint` with the `"scrape"` endpoint name.
pub async fn handle_scrape(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "scrape").await
}
/// Handler that delegates to `handle_endpoint` with the `"json"` endpoint name.
pub async fn handle_json(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "json").await
}
/// Handler that delegates to `handle_endpoint` with the `"links"` endpoint name.
pub async fn handle_links(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "links").await
}

View File

@@ -0,0 +1,72 @@
pub mod auth;
pub mod crawl;
pub mod oauth;
pub mod stripe;
pub mod ai;
pub mod teams;
pub mod ws;
use axum::{
middleware,
routing::{get, post},
Router,
};
use tower_http::cors::CorsLayer;
use crate::{
middleware::{auth::api_key_middleware, correlation::correlation_id_middleware, jwt::jwt_middleware, rate_limit::rate_limit_middleware, waf::ip_rate_limit_middleware},
state::AppState,
};
/// Builds the application router.
///
/// Four route groups are nested under `/api`:
/// * crawl/AI routes, wrapped in the API-key and per-key rate-limit middleware;
/// * public auth routes (register, login, Google OAuth);
/// * JWT-protected routes (API keys, Stripe checkout, teams);
/// * the Stripe webhook.
///
/// `/metrics` and `/ws/logs` are mounted at the root. The IP rate-limit,
/// correlation-ID and permissive CORS layers are applied to the whole router.
pub fn create_router(state: AppState) -> Router {
// Routes authenticated by API key.
let api_routes = Router::new()
.route("/crawl", post(crawl::handle_crawl))
.route("/content", post(crawl::handle_content))
.route("/screenshot", post(crawl::handle_screenshot))
.route("/pdf", post(crawl::handle_pdf))
.route("/markdown", post(crawl::handle_markdown))
.route("/snapshot", post(crawl::handle_snapshot))
.route("/scrape", post(crawl::handle_scrape))
.route("/json", post(crawl::handle_json))
.route("/links", post(crawl::handle_links))
.route("/extract", post(ai::extract))
// NOTE(review): in axum a layer added later wraps the earlier ones, so the
// rate limiter presumably runs before API-key authentication — confirm intended.
.route_layer(middleware::from_fn_with_state(state.clone(), api_key_middleware))
.route_layer(middleware::from_fn_with_state(state.clone(), rate_limit_middleware));
// Public authentication routes (no auth middleware).
let auth_routes = Router::new()
.route("/auth/register", post(auth::register))
.route("/auth/login", post(auth::login))
.route("/auth/google", get(oauth::google_auth_url))
.route("/auth/google/callback", get(oauth::google_callback));
// Routes that require a valid JWT.
let protected_routes = Router::new()
.route("/auth/api-keys", post(auth::create_api_key))
.route("/auth/api-keys", get(auth::list_api_keys))
.route("/auth/api-keys/{id}", axum::routing::delete(auth::delete_api_key))
.route("/stripe/checkout", post(stripe::create_checkout))
.route("/teams", post(teams::create))
.route("/teams/{slug}", get(teams::get))
.route("/teams/{slug}/members", post(teams::add_member))
.route_layer(middleware::from_fn_with_state(state.clone(), jwt_middleware));
// Stripe webhook: no API-key or JWT middleware is attached here.
let stripe_webhook = Router::new()
.route("/stripe/webhook", post(stripe::webhook));
Router::new()
.nest("/api", api_routes)
.nest("/api", auth_routes)
.nest("/api", protected_routes)
.nest("/api", stripe_webhook)
// Prometheus text exposition of everything in the crate's metrics registry.
// Panics if encoding fails or the output is not valid UTF-8.
.route("/metrics", get(|| async {
use prometheus::Encoder;
let encoder = prometheus::TextEncoder::new();
let mut buffer = vec![];
encoder.encode(&crate::metrics::REGISTRY.gather(), &mut buffer).unwrap();
String::from_utf8(buffer).unwrap()
}))
.route("/ws/logs", get(ws::live_logs))
// Layers below apply to every route registered above.
.layer(middleware::from_fn_with_state(state.clone(), ip_rate_limit_middleware))
.layer(middleware::from_fn_with_state(state.clone(), correlation_id_middleware))
.layer(CorsLayer::permissive())
.with_state(state)
}

View File

@@ -0,0 +1,137 @@
use axum::{
extract::{Query, State},
http::StatusCode,
Json,
};
use db::repos::{oauth, users};
use serde::{Deserialize, Serialize};
use crate::state::AppState;
/// Query parameters of the Google OAuth callback request.
#[derive(Debug, Deserialize)]
pub struct GoogleCallback {
/// Authorization code to exchange for tokens.
pub code: String,
}
/// Response carrying the Google consent-screen URL.
#[derive(Debug, Serialize)]
pub struct GoogleAuthUrl {
/// URL the client should open to start the OAuth flow.
pub url: String,
}
/// Subset of Google's token-endpoint response that is deserialized here.
#[derive(Debug, Deserialize)]
struct GoogleTokenResponse {
/// Bearer token used to call the userinfo endpoint.
access_token: String,
/// ID token, when returned; not read by the visible callback code.
id_token: Option<String>,
}
/// Subset of Google's userinfo response that is deserialized here.
#[derive(Debug, Deserialize)]
struct GoogleUserInfo {
/// Google account identifier; used as the provider account ID.
sub: String,
/// Email address of the Google account.
email: String,
/// Display name, when provided; not read by the visible callback code.
name: Option<String>,
/// Profile picture URL, when provided; not read by the visible callback code.
picture: Option<String>,
}
pub async fn google_auth_url(State(_state): State<AppState>) -> Result<Json<GoogleAuthUrl>, StatusCode> {
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
if client_id.is_empty() {
return Err(StatusCode::NOT_IMPLEMENTED);
}
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
let url = format!(
"https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=email%20profile&access_type=offline&prompt=consent",
client_id, redirect_uri
);
Ok(Json(GoogleAuthUrl { url }))
}
/// Completes the Google OAuth flow and returns the user with a JWT.
///
/// With Google credentials configured, the authorization code is exchanged
/// for an access token, the user's profile is fetched, and the matching user
/// is found (by linked OAuth account, then by email) or created.
///
/// Without credentials, a mock user derived from the `code` parameter is
/// created **in debug builds only**. In release builds the endpoint answers
/// `501 Not Implemented`, because the mock path hands out accounts and tokens
/// to anyone who calls it.
///
/// NOTE(review): an existing user is linked by email without checking that
/// Google reports the address as verified — confirm this is acceptable.
///
/// # Errors
/// * `501 Not Implemented` – credentials missing in a release build.
/// * `401 Unauthorized` – Google rejected the authorization code.
/// * `502 Bad Gateway` – the token endpoint failed for another reason.
/// * `500 Internal Server Error` – a network, decoding, repository or token
///   error occurred.
pub async fn google_callback(
    State(state): State<AppState>,
    Query(params): Query<GoogleCallback>,
) -> Result<Json<super::auth::AuthResponse>, StatusCode> {
    let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
    let client_secret = std::env::var("GOOGLE_CLIENT_SECRET").unwrap_or_default();
    let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
        .unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());

    if client_id.is_empty() || client_secret.is_empty() {
        // The mock login must never be reachable in a production build.
        if !cfg!(debug_assertions) {
            return Err(StatusCode::NOT_IMPLEMENTED);
        }

        // MVP fallback: create mock user. Take characters, not bytes, so a
        // non-ASCII `code` cannot cause a slicing panic.
        let code_prefix: String = params.code.chars().take(8).collect();
        let email = format!("google_user_{}@example.com", code_prefix);
        let user = match users::find_by_email(&state.db, &email)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        {
            Some(u) => u,
            None => {
                let u = users::create(&state.db, &email, None, Some(&params.code))
                    .await
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
                let _ = oauth::create(&state.db, u.id, "google", &params.code).await;
                u
            }
        };
        let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
        return Ok(Json(super::auth::AuthResponse { user, token }));
    }

    // Exchange code for token
    let client = reqwest::Client::new();
    let token_response = client
        .post("https://oauth2.googleapis.com/token")
        .form(&[
            ("code", params.code.as_str()),
            ("client_id", client_id.as_str()),
            ("client_secret", client_secret.as_str()),
            ("redirect_uri", redirect_uri.as_str()),
            ("grant_type", "authorization_code"),
        ])
        .send()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let token_status = token_response.status();
    if token_status.is_client_error() {
        // Invalid, expired or already-used authorization code.
        return Err(StatusCode::UNAUTHORIZED);
    }
    if !token_status.is_success() {
        return Err(StatusCode::BAD_GATEWAY);
    }
    let token_res = token_response
        .json::<GoogleTokenResponse>()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Get user info
    let user_info = client
        .get("https://openidconnect.googleapis.com/v1/userinfo")
        .header("Authorization", format!("Bearer {}", token_res.access_token))
        .send()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .json::<GoogleUserInfo>()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Find or create user
    let linked_account = oauth::find_by_provider(&state.db, "google", &user_info.sub)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let user = match linked_account {
        Some(oauth_account) => users::find_by_id(&state.db, oauth_account.user_id)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?,
        None => {
            let user = match users::find_by_email(&state.db, &user_info.email)
                .await
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
            {
                Some(u) => u,
                None => users::create(&state.db, &user_info.email, None, Some(&user_info.sub))
                    .await
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
            };
            oauth::create(&state.db, user.id, "google", &user_info.sub)
                .await
                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
            user
        }
    };

    let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
    Ok(Json(super::auth::AuthResponse { user, token }))
}

View File

@@ -0,0 +1,146 @@
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use db::repos::{subscriptions, users};
use serde::{Deserialize, Serialize};
use crate::{middleware::jwt::JwtAuth, state::AppState};
use axum::Extension;
/// JSON request body accepted by `create_checkout`.
#[derive(Debug, Deserialize)]
pub struct CreateCheckoutRequest {
/// Sent to Stripe as `line_items[0][price]` when the checkout session is created.
pub price_id: String,
}
/// JSON response returned by `create_checkout`.
#[derive(Debug, Serialize)]
pub struct CheckoutResponse {
/// The `url` field of the checkout session object returned by Stripe.
pub checkout_url: String,
}
pub async fn create_checkout(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Json(body): Json<CreateCheckoutRequest>,
) -> Result<Json<CheckoutResponse>, StatusCode> {
let stripe_secret = std::env::var("STRIPE_SECRET_KEY").map_err(|_| StatusCode::NOT_IMPLEMENTED)?;
let user = users::find_by_id(&state.db, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Create Stripe customer via HTTP API directly (simpler than SDK for MVP)
let client = reqwest::Client::new();
let customer_res = client
.post("https://api.stripe.com/v1/customers")
.basic_auth(&stripe_secret, Some(""))
.form(&[("email", &user.email)])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let customer_data: serde_json::Value = customer_res.json().await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let customer_id = customer_data["id"].as_str().unwrap_or("");
let success_url = std::env::var("STRIPE_SUCCESS_URL")
.unwrap_or_else(|_| "http://localhost:3000/dashboard?success=true".to_string());
let cancel_url = std::env::var("STRIPE_CANCEL_URL")
.unwrap_or_else(|_| "http://localhost:3000/dashboard?canceled=true".to_string());
let session_res = client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(&stripe_secret, Some(""))
.form(&[
("customer", customer_id),
("success_url", &success_url),
("cancel_url", &cancel_url),
("mode", "subscription"),
("line_items[0][price]", &body.price_id),
("line_items[0][quantity]", "1"),
("metadata[user_id]", &auth.user_id.to_string()),
])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let session_data: serde_json::Value = session_res.json().await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let url = session_data["url"].as_str().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(CheckoutResponse { checkout_url: url.to_string() }))
}
/// Typed view of a Stripe webhook event envelope.
///
/// NOTE(review): the `webhook` handler in this file parses the body into a
/// `serde_json::Value` rather than this struct — confirm whether this type is
/// still used elsewhere.
#[derive(Debug, Deserialize)]
pub struct StripeWebhook {
/// Deserialized from the JSON key `type` (renamed because `type` is a Rust keyword).
#[serde(rename = "type")]
pub event_type: String,
/// Deserialized from the JSON key `data`.
pub data: StripeEventData,
}
/// The `data` member of a `StripeWebhook` event.
#[derive(Debug, Deserialize)]
pub struct StripeEventData {
/// Kept as untyped JSON; this struct imposes no schema on it.
pub object: serde_json::Value,
}
pub async fn webhook(
State(state): State<AppState>,
headers: HeaderMap,
body: String,
) -> Result<StatusCode, StatusCode> {
let stripe_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
// Verify webhook signature if configured
if !stripe_secret.is_empty() {
let sig = headers
.get("stripe-signature")
.and_then(|v| v.to_str().ok())
.ok_or(StatusCode::BAD_REQUEST)?;
// In production, verify signature using Stripe library
// For MVP, we log and process
tracing::info!("Webhook signature: {}", sig);
}
let event: serde_json::Value = serde_json::from_str(&body).map_err(|_| StatusCode::BAD_REQUEST)?;
let event_type = event["type"].as_str().unwrap_or("");
match event_type {
"checkout.session.completed" => {
if let Some(metadata) = event["data"]["object"]["metadata"].as_object() {
if let Some(user_id_str) = metadata.get("user_id").and_then(|v| v.as_str()) {
if let Ok(user_id) = uuid::Uuid::parse_str(user_id_str) {
let customer_id = event["data"]["object"]["customer"].as_str().unwrap_or("");
let subscription_id = event["data"]["object"]["subscription"].as_str().unwrap_or("");
let _ = subscriptions::create_or_update(
&state.db,
user_id,
Some(customer_id),
Some(subscription_id),
None,
"active",
"paid",
).await;
tracing::info!("Subscription activated for user {}", user_id);
}
}
}
}
"invoice.payment_succeeded" => {
tracing::info!("Invoice payment succeeded");
}
"customer.subscription.deleted" => {
if let Some(sub_id) = event["data"]["object"]["id"].as_str() {
let _ = subscriptions::update_status(&state.db, sub_id, "canceled").await;
tracing::info!("Subscription {} canceled", sub_id);
}
}
_ => {}
}
Ok(StatusCode::OK)
}

View File

@@ -0,0 +1,95 @@
use axum::{
extract::{Json, Path, State},
http::StatusCode,
Extension,
};
use db::repos::teams;
use serde::{Deserialize, Serialize};
use shared::models::{Team, TeamMember};
use uuid::Uuid;
use crate::{middleware::jwt::JwtAuth, state::AppState};
/// JSON request body accepted by the team `create` handler.
#[derive(Debug, Deserialize)]
pub struct CreateTeamRequest {
/// Passed to `teams::create` as the team name.
pub name: String,
/// Passed to `teams::create` as the slug; the other handlers look teams up by slug.
pub slug: String,
}
/// JSON response returned by the team `get` handler.
#[derive(Debug, Serialize)]
pub struct TeamResponse {
/// The team found by slug.
pub team: Team,
/// The result of `teams::list_members` for that team.
pub members: Vec<TeamMember>,
}
/// Creates a team owned by the authenticated user and enrols that user as a
/// member with the `"owner"` role.
///
/// # Errors
/// `INTERNAL_SERVER_ERROR` when either the team insert or the owner-membership
/// insert fails.
pub async fn create(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
    Json(body): Json<CreateTeamRequest>,
) -> Result<Json<Team>, StatusCode> {
    let team = teams::create(&state.db, &body.name, &body.slug, auth.user_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Add owner as member. This must not fail silently: the `get` handler
    // authorizes through the membership lookup, so an owner without a
    // membership row would be refused access to their own team.
    // NOTE(review): the two inserts are not wrapped in a transaction here, so a
    // failure at this point leaves the team row behind.
    teams::add_member(&state.db, team.id, auth.user_id, "owner")
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(team))
}
/// Returns the team identified by `slug` together with its member list.
///
/// Responds with `NOT_FOUND` for an unknown slug, `FORBIDDEN` when the caller
/// is not a member of the team, and `INTERNAL_SERVER_ERROR` for any
/// repository failure.
pub async fn get(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
    Path(slug): Path<String>,
) -> Result<Json<TeamResponse>, StatusCode> {
    /// Collapses any repository error into a 500.
    fn internal_error<E>(_err: E) -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    let team = match teams::find_by_slug(&state.db, &slug)
        .await
        .map_err(internal_error)?
    {
        Some(found) => found,
        None => return Err(StatusCode::NOT_FOUND),
    };

    // Only members may see the team.
    let membership = teams::find_member(&state.db, team.id, auth.user_id)
        .await
        .map_err(internal_error)?;
    match membership {
        None => Err(StatusCode::FORBIDDEN),
        Some(_) => {
            let members = teams::list_members(&state.db, team.id)
                .await
                .map_err(internal_error)?;
            Ok(Json(TeamResponse { team, members }))
        }
    }
}
/// JSON request body accepted by the team `add_member` handler.
#[derive(Debug, Deserialize)]
pub struct AddMemberRequest {
/// The user to add to the team.
pub user_id: Uuid,
/// Optional in the JSON: `#[serde(default)]` yields an empty string when the
/// key is absent, which `add_member` replaces with `"member"`.
#[serde(default)]
pub role: String,
}
/// Adds a user to the team identified by `slug`.
///
/// Only the team owner may call this; everyone else receives `FORBIDDEN`.
/// An empty `role` in the request falls back to `"member"`. Unknown slugs give
/// `NOT_FOUND`, repository failures give `INTERNAL_SERVER_ERROR`.
pub async fn add_member(
    State(state): State<AppState>,
    Extension(auth): Extension<JwtAuth>,
    Path(slug): Path<String>,
    Json(body): Json<AddMemberRequest>,
) -> Result<Json<TeamMember>, StatusCode> {
    let lookup = teams::find_by_slug(&state.db, &slug)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let team = match lookup {
        Some(found) => found,
        None => return Err(StatusCode::NOT_FOUND),
    };

    // Ownership gate.
    let caller_is_owner = team.owner_id == auth.user_id;
    if !caller_is_owner {
        return Err(StatusCode::FORBIDDEN);
    }

    let role: &str = match body.role.as_str() {
        "" => "member",
        requested => requested,
    };

    match teams::add_member(&state.db, team.id, body.user_id, role).await {
        Ok(member) => Ok(Json(member)),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

View File

@@ -0,0 +1,41 @@
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
response::IntoResponse,
};
pub async fn live_logs(ws: WebSocketUpgrade) -> impl IntoResponse {
ws.on_upgrade(handle_socket)
}
/// Serves one live-log WebSocket connection.
///
/// Sends a `connected` greeting, then emits a JSON heartbeat on a 5-second
/// interval and answers the text frame `ping` with `pong`. The loop ends when
/// the peer closes the connection, the stream ends, or a heartbeat cannot be
/// sent.
async fn handle_socket(mut socket: WebSocket) {
    // Greeting; a failed send is ignored here.
    let greeting = r#"{"type":"connected","message":"Live logs connected"}"#;
    let _ = socket.send(Message::Text(greeting.to_string())).await;

    // In a real implementation, this would subscribe to a Redis pub/sub channel
    // and stream logs to the client. For MVP, we send a heartbeat.
    let mut heartbeat = tokio::time::interval(std::time::Duration::from_secs(5));
    loop {
        tokio::select! {
            _ = heartbeat.tick() => {
                let payload = format!(
                    r#"{{"type":"heartbeat","timestamp":"{}"}}"#,
                    chrono::Utc::now().to_rfc3339()
                );
                if socket.send(Message::Text(payload)).await.is_err() {
                    break;
                }
            }
            incoming = socket.recv() => match incoming {
                None | Some(Ok(Message::Close(_))) => break,
                Some(Ok(Message::Text(text))) if text == "ping" => {
                    let _ = socket.send(Message::Text("pong".to_string())).await;
                }
                // Other frames and receive errors are ignored.
                _ => {}
            },
        }
    }
}

View File

@@ -0,0 +1,69 @@
use shared::error::AppError;
/// Resolves a secret by name, trying each provider in order and returning the
/// first hit:
///
/// 1. the process environment (for local dev),
/// 2. Vault,
/// 3. AWS Secrets Manager.
///
/// Provider errors are not reported individually; if none yields a value an
/// `AppError::Internal` naming the key is returned.
pub async fn get_secret(key: &str) -> Result<String, AppError> {
    let resolved = match std::env::var(key).ok() {
        Some(value) => Some(value),
        None => match get_vault_secret(key).await.ok() {
            Some(value) => Some(value),
            None => get_aws_secret(key).await.ok(),
        },
    };

    resolved.ok_or_else(|| {
        AppError::Internal(format!("Secret {} not found in any provider", key))
    })
}
/// Fetches `key` from Vault at `{VAULT_ADDR}/v1/secret/data/{key}`.
///
/// Requires `VAULT_ADDR` and `VAULT_TOKEN` in the environment. The value is
/// read from `data.data.<key>` of the JSON response and must be a string.
///
/// # Errors
/// `AppError::Internal` when either variable is unset, the request or JSON
/// decoding fails, or the expected field is absent.
async fn get_vault_secret(key: &str) -> Result<String, AppError> {
    let vault_addr = std::env::var("VAULT_ADDR")
        .map_err(|_| AppError::Internal("VAULT_ADDR not set".to_string()))?;
    let vault_token = std::env::var("VAULT_TOKEN")
        .map_err(|_| AppError::Internal("VAULT_TOKEN not set".to_string()))?;

    let endpoint = format!("{}/v1/secret/data/{}", vault_addr, key);
    let request = reqwest::Client::new()
        .get(endpoint)
        .header("X-Vault-Token", vault_token);

    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => return Err(AppError::Internal(format!("Vault request failed: {}", e))),
    };
    let payload: serde_json::Value = match response.json().await {
        Ok(json) => json,
        Err(e) => {
            return Err(AppError::Internal(format!(
                "Vault response parse failed: {}",
                e
            )))
        }
    };

    match payload["data"]["data"][key].as_str() {
        Some(secret) => Ok(secret.to_string()),
        None => Err(AppError::Internal(format!(
            "Secret {} not found in Vault",
            key
        ))),
    }
}
/// Fetches `key` from AWS Secrets Manager.
///
/// The secret id is `crawlapi/` followed by the lower-cased key with every `_`
/// replaced by `/` (so `JWT_SECRET` becomes `crawlapi/jwt/secret`). AWS
/// configuration is loaded from the environment.
///
/// # Errors
/// `AppError::Internal` when the SDK call fails or the secret has no string value.
async fn get_aws_secret(key: &str) -> Result<String, AppError> {
    let secret_path = key.to_lowercase().replace('_', "/");
    let secret_name = format!("crawlapi/{}", secret_path);

    let sdk_config = aws_config::from_env().load().await;
    let secrets = aws_sdk_secretsmanager::Client::new(&sdk_config);

    let lookup = secrets
        .get_secret_value()
        .secret_id(&secret_name)
        .send()
        .await;
    let response = match lookup {
        Ok(found) => found,
        Err(e) => {
            return Err(AppError::Internal(format!(
                "AWS Secrets Manager error: {}",
                e
            )))
        }
    };

    match response.secret_string() {
        Some(value) => Ok(value.to_string()),
        None => Err(AppError::Internal(format!(
            "Secret {} not found in AWS",
            key
        ))),
    }
}

13
crates/api/src/state.rs Normal file
View File

@@ -0,0 +1,13 @@
use aws_sdk_s3::Client as S3Client;
use db::DbPool;
use redis::aio::MultiplexedConnection;
use shared::config::AppConfig;
use std::sync::Arc;
/// Shared application state, extracted by handlers as `State<AppState>`.
///
/// Derives `Clone`, so every field must itself be cloneable.
#[derive(Clone)]
pub struct AppState {
/// Application configuration behind an `Arc`.
pub config: Arc<AppConfig>,
/// Database pool handed to the `db::repos` functions.
pub db: DbPool,
/// Multiplexed async Redis connection.
pub redis: MultiplexedConnection,
/// S3 client.
pub s3: S3Client,
}

View File

@@ -0,0 +1,54 @@
use aws_sdk_s3::Client as S3Client;
use shared::error::AppError;
use uuid::Uuid;
/// Uploads `data` to `bucket` under `key` with the given content type.
///
/// Returns the object location as `"{bucket}/{key}"`.
///
/// # Errors
/// `AppError::S3` carrying the SDK error text when the put fails.
pub async fn upload_file(
    s3: &S3Client,
    bucket: &str,
    key: &str,
    content_type: &str,
    data: Vec<u8>,
) -> Result<String, AppError> {
    let put_request = s3
        .put_object()
        .bucket(bucket)
        .key(key)
        .content_type(content_type)
        .body(data.into());

    if let Err(e) = put_request.send().await {
        return Err(AppError::S3(e.to_string()));
    }

    let location = format!("{}/{}", bucket, key);
    Ok(location)
}
/// Creates `bucket` unless a `HeadBucket` request for it succeeds.
///
/// Any `HeadBucket` failure is treated as "bucket missing" and triggers a
/// create attempt.
///
/// # Errors
/// `AppError::S3` carrying the SDK error text when the create fails.
pub async fn ensure_bucket_exists(s3: &S3Client, bucket: &str) -> Result<(), AppError> {
    let head_result = s3.head_bucket().bucket(bucket).send().await;
    if head_result.is_ok() {
        return Ok(());
    }

    match s3.create_bucket().bucket(bucket).send().await {
        Ok(_) => Ok(()),
        Err(e) => Err(AppError::S3(e.to_string())),
    }
}
/// Builds a fresh, random object key for an endpoint's output.
///
/// * `"screenshot"` gives `screenshots/<id>.png`
/// * `"pdf"` gives `pdfs/<id>.pdf`
/// * anything else gives `files/<id>.<ext>`
///
/// `<id>` is a v4 UUID with the hyphens removed. `ext` is only used for the
/// fallback case.
pub fn generate_file_key(endpoint: &str, ext: &str) -> String {
    let (folder, suffix) = match endpoint {
        "screenshot" => ("screenshots", "png"),
        "pdf" => ("pdfs", "pdf"),
        _ => ("files", ext),
    };
    let id: String = Uuid::new_v4()
        .to_string()
        .chars()
        .filter(|c| *c != '-')
        .collect();
    format!("{}/{}.{}", folder, id, suffix)
}
/// Joins the S3 endpoint, bucket and key with `/` to form an object URL.
///
/// `_endpoint` is accepted but not used.
pub fn get_public_url(_endpoint: &str, s3_endpoint: &str, bucket: &str, key: &str) -> String {
    [s3_endpoint, bucket, key].join("/")
}

View File

@@ -0,0 +1,62 @@
use shared::error::AppError;
/// Returns `true` for IPv4 addresses that must not be reached from the
/// service: loopback (127.0.0.0/8), RFC 1918 private ranges, link-local
/// (169.254.0.0/16, which includes cloud metadata endpoints), the unspecified
/// address and the broadcast address.
fn is_internal_ipv4(addr: std::net::Ipv4Addr) -> bool {
    addr.is_loopback()
        || addr.is_private()
        || addr.is_link_local()
        || addr.is_unspecified()
        || addr.is_broadcast()
}

/// Returns `true` for IPv6 addresses that must not be reached from the
/// service: loopback, unspecified, unique-local (fc00::/7), link-local
/// (fe80::/10), and IPv4-mapped addresses whose embedded IPv4 is internal.
fn is_internal_ipv6(addr: std::net::Ipv6Addr) -> bool {
    let seg = addr.segments();
    // ::ffff:a.b.c.d — judge the embedded IPv4 address.
    if seg[..5].iter().all(|s| *s == 0) && seg[5] == 0xffff {
        let mapped = std::net::Ipv4Addr::new(
            (seg[6] >> 8) as u8,
            (seg[6] & 0xff) as u8,
            (seg[7] >> 8) as u8,
            (seg[7] & 0xff) as u8,
        );
        return is_internal_ipv4(mapped);
    }
    addr.is_loopback()
        || addr.is_unspecified()
        || (seg[0] & 0xfe00) == 0xfc00
        || (seg[0] & 0xffc0) == 0xfe80
}

/// Returns `true` when the URL's host is `localhost` (or a subdomain of it) or
/// an internal IP literal.
///
/// Works on the parsed host rather than on string prefixes, so every internal
/// range is covered and public domains such as `10.example.com` are not
/// rejected by accident.
///
/// NOTE(review): this inspects the literal host only. A public DNS name that
/// resolves to an internal address is not detected here; that needs a check at
/// connection time.
fn is_internal_host(parsed: &url::Url) -> bool {
    match parsed.host() {
        Some(url::Host::Domain(name)) => {
            let name = name.trim_end_matches('.').to_ascii_lowercase();
            name == "localhost" || name.ends_with(".localhost")
        }
        Some(url::Host::Ipv4(addr)) => is_internal_ipv4(addr),
        Some(url::Host::Ipv6(addr)) => is_internal_ipv6(addr),
        None => false,
    }
}

/// Validates a user-supplied target URL.
///
/// # Errors
/// `AppError::InvalidUrl` when the string does not parse as a URL, the scheme
/// is not `http`/`https`, or the host is `localhost` or an internal IP
/// address.
pub fn validate_url(url: &str) -> Result<(), AppError> {
    let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;

    // Only allow http and https. This also rejects file://, ftp://, etc., so
    // no separate scheme check is needed afterwards.
    if !matches!(parsed.scheme(), "http" | "https") {
        return Err(AppError::InvalidUrl(
            "Only HTTP and HTTPS URLs are allowed".to_string(),
        ));
    }

    // Block private IP ranges
    if is_internal_host(&parsed) {
        return Err(AppError::InvalidUrl(
            "Private IP addresses are not allowed".to_string(),
        ));
    }

    Ok(())
}

/// Validates a user-supplied webhook callback URL.
///
/// Applies the same scheme and host rules as `validate_url`, with
/// webhook-specific error messages.
///
/// # Errors
/// `AppError::InvalidUrl` when the string does not parse as a URL, the scheme
/// is not `http`/`https`, or the host is `localhost` or an internal IP
/// address.
pub fn validate_webhook_url(url: &str) -> Result<(), AppError> {
    let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;

    // Only allow http and https
    if !matches!(parsed.scheme(), "http" | "https") {
        return Err(AppError::InvalidUrl(
            "Webhook must use HTTP or HTTPS".to_string(),
        ));
    }

    // Block private IPs and localhost for webhooks
    if is_internal_host(&parsed) {
        return Err(AppError::InvalidUrl(
            "Webhook cannot point to private addresses".to_string(),
        ));
    }

    Ok(())
}
/// Checks that `content` is no larger than `max_mb` mebibytes.
///
/// A payload of exactly `max_mb * 1024 * 1024` bytes is accepted.
///
/// # Errors
/// `AppError::BadRequest` when the content is larger than the limit.
pub fn validate_size(content: &[u8], max_mb: usize) -> Result<(), AppError> {
    // Saturate instead of overflowing: a huge `max_mb` means "effectively
    // unlimited", not a panic (debug) or a wrapped, tiny limit (release).
    let max_bytes = max_mb.saturating_mul(1024 * 1024);
    if content.len() > max_bytes {
        return Err(AppError::BadRequest(format!(
            "Content exceeds maximum size of {}MB",
            max_mb
        )));
    }
    Ok(())
}