Initial commit: Full Crawl API implementation
This commit is contained in:
39
crates/api/Cargo.toml
Normal file
39
crates/api/Cargo.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
shared = { path = "../shared" }
|
||||
db = { path = "../db" }
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
tokio = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
aws-config = { workspace = true }
|
||||
aws-sdk-s3 = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
bcrypt = { workspace = true }
|
||||
config = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
url = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
scraper = { workspace = true }
|
||||
markdown = { workspace = true }
|
||||
md5 = "0.7"
|
||||
prometheus = "0.13"
|
||||
lazy_static = "1.5"
|
||||
sentry = "0.36"
|
||||
async-stripe = { version = "1.0.0-rc.5", features = ["default-tls"] }
|
||||
aws-sdk-secretsmanager = "1.0"
|
||||
51
crates/api/src/bin/seed.rs
Normal file
51
crates/api/src/bin/seed.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use db::connection::create_pool;
|
||||
use db::repos::{api_keys, users};
|
||||
use shared::config::AppConfig;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let config = Arc::new(AppConfig::from_env()?);
|
||||
let db = create_pool(&config.database_url).await?;
|
||||
|
||||
// Create test user
|
||||
let email = "demo@crawlapi.dev";
|
||||
let password = "demo123456";
|
||||
let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
|
||||
|
||||
let user = match users::find_by_email(&db, email).await? {
|
||||
Some(u) => {
|
||||
tracing::info!("User already exists: {}", u.id);
|
||||
u
|
||||
}
|
||||
None => {
|
||||
let u = users::create(&db, email, Some(&password_hash), None).await?;
|
||||
tracing::info!("Created user: {} with 30 free credits", u.id);
|
||||
u
|
||||
}
|
||||
};
|
||||
|
||||
// Create API key
|
||||
let api_key = format!("crawlapi_demo_{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let key_hash = format!("{:x}", md5::compute(&api_key));
|
||||
|
||||
let key = api_keys::create(&db, user.id, &key_hash, "Demo Key").await?;
|
||||
tracing::info!("Created API key: {} (id: {})", api_key, key.id);
|
||||
|
||||
println!("\n========================================");
|
||||
println!("SEED DATA CREATED");
|
||||
println!("========================================");
|
||||
println!("Email: {}", email);
|
||||
println!("Password: {}", password);
|
||||
println!("API Key: {}", api_key);
|
||||
println!("Credits: {}", user.credits);
|
||||
println!("========================================\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
crates/api/src/lib.rs
Normal file
8
crates/api/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod middleware;
|
||||
pub mod metrics;
|
||||
pub mod queue;
|
||||
pub mod routes;
|
||||
pub mod secrets;
|
||||
pub mod state;
|
||||
pub mod storage;
|
||||
pub mod validation;
|
||||
79
crates/api/src/main.rs
Normal file
79
crates/api/src/main.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use std::sync::Arc;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
use api::{metrics, routes, state::AppState, storage};
|
||||
use db::connection::create_pool;
|
||||
use shared::config::AppConfig;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let sentry_dsn = std::env::var("SENTRY_DSN").ok();
|
||||
let _guard = sentry_dsn.map(|dsn| {
|
||||
sentry::init((dsn, sentry::ClientOptions {
|
||||
release: sentry::release_name!(),
|
||||
..Default::default()
|
||||
}))
|
||||
});
|
||||
|
||||
// Structured JSON logging with correlation IDs
|
||||
let json_logging = std::env::var("JSON_LOGGING").unwrap_or_else(|_| "false".to_string()) == "true";
|
||||
|
||||
if json_logging {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
|
||||
)
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.json()
|
||||
.with_current_span(true)
|
||||
.with_span_list(true)
|
||||
.with_target(true),
|
||||
)
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
}
|
||||
|
||||
metrics::register_metrics();
|
||||
|
||||
let config = Arc::new(AppConfig::from_env()?);
|
||||
|
||||
let db = create_pool(&config.database_url).await?;
|
||||
sqlx::migrate!("../db/migrations").run(&db).await?;
|
||||
|
||||
let redis = redis::Client::open(config.redis_url.clone())?;
|
||||
let redis_conn = redis.get_multiplexed_tokio_connection().await?;
|
||||
|
||||
let s3_config = aws_config::from_env()
|
||||
.endpoint_url(&config.s3_endpoint)
|
||||
.load()
|
||||
.await;
|
||||
let s3 = aws_sdk_s3::Client::new(&s3_config);
|
||||
|
||||
storage::ensure_bucket_exists(&s3, &config.s3_bucket).await?;
|
||||
|
||||
let state = AppState {
|
||||
config,
|
||||
db,
|
||||
redis: redis_conn,
|
||||
s3,
|
||||
};
|
||||
|
||||
let app = routes::create_router(state)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
|
||||
tracing::info!("API server listening on {}", listener.local_addr()?);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
31
crates/api/src/metrics.rs
Normal file
31
crates/api/src/metrics.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{CounterVec, HistogramVec, Registry};
|
||||
use std::time::Instant;
|
||||
|
||||
lazy_static! {
    /// Registry that the request metrics below are registered into by
    /// `register_metrics`.
    pub static ref REGISTRY: Registry = Registry::new();

    /// `api_requests_total`: request counter labelled by `endpoint` and `status`.
    pub static ref REQUEST_COUNTER: CounterVec = {
        let opts = prometheus::Opts::new("api_requests_total", "Total API requests");
        CounterVec::new(opts, &["endpoint", "status"]).unwrap()
    };

    /// `api_request_duration_seconds`: request latency histogram labelled by `endpoint`.
    pub static ref REQUEST_DURATION: HistogramVec = {
        let opts = prometheus::HistogramOpts::new(
            "api_request_duration_seconds",
            "Request duration in seconds",
        );
        HistogramVec::new(opts, &["endpoint"]).unwrap()
    };
}
|
||||
|
||||
pub fn register_metrics() {
|
||||
REGISTRY.register(Box::new(REQUEST_COUNTER.clone())).unwrap();
|
||||
REGISTRY.register(Box::new(REQUEST_DURATION.clone())).unwrap();
|
||||
}
|
||||
|
||||
pub fn record_request(endpoint: &str, status: &str) {
|
||||
REQUEST_COUNTER.with_label_values(&[endpoint, status]).inc();
|
||||
}
|
||||
|
||||
pub fn record_duration(endpoint: &str, start: Instant) {
|
||||
let duration = start.elapsed().as_secs_f64();
|
||||
REQUEST_DURATION.with_label_values(&[endpoint]).observe(duration);
|
||||
}
|
||||
52
crates/api/src/middleware/auth.rs
Normal file
52
crates/api/src/middleware/auth.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use db::repos::api_keys;
|
||||
use shared::models::User;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiKeyAuth {
|
||||
pub user: User,
|
||||
pub api_key_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
pub async fn api_key_middleware(
|
||||
State(state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let api_key = req
|
||||
.headers()
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let key_hash = format!("{:x}", md5::compute(api_key));
|
||||
|
||||
let api_key_record = api_keys::find_by_key_hash(&state.db, &key_hash)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user = db::repos::users::find_by_id(&state.db, api_key_record.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
api_keys::update_last_used(&state.db, api_key_record.id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let auth = ApiKeyAuth {
|
||||
user,
|
||||
api_key_id: api_key_record.id,
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(auth);
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
42
crates/api/src/middleware/correlation.rs
Normal file
42
crates/api/src/middleware/correlation.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header::HeaderValue, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::{info_span, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn correlation_id_middleware(
|
||||
State(_state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let correlation_id = req
|
||||
.headers()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
|
||||
req.headers_mut().insert(
|
||||
"x-correlation-id",
|
||||
HeaderValue::from_str(&correlation_id).unwrap(),
|
||||
);
|
||||
|
||||
let method = req.method().to_string();
|
||||
let uri = req.uri().to_string();
|
||||
|
||||
let span = info_span!(
|
||||
"http_request",
|
||||
correlation_id = %correlation_id,
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
);
|
||||
|
||||
let response = next.run(req).instrument(span).await;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
49
crates/api/src/middleware/jwt.rs
Normal file
49
crates/api/src/middleware/jwt.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct JwtClaims {
|
||||
pub sub: String,
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JwtAuth {
|
||||
pub user_id: Uuid,
|
||||
}
|
||||
|
||||
pub async fn jwt_middleware(
|
||||
State(state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let token = req
|
||||
.headers()
|
||||
.get("x-auth-token")
|
||||
.or_else(|| req.headers().get("authorization"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer ").or(Some(v)))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let validation = Validation::default();
|
||||
let token_data = decode::<JwtClaims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(state.config.jwt_secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user_id = Uuid::parse_str(&token_data.claims.sub).map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
req.extensions_mut().insert(JwtAuth { user_id });
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
5
crates/api/src/middleware/mod.rs
Normal file
5
crates/api/src/middleware/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod auth;
|
||||
pub mod correlation;
|
||||
pub mod jwt;
|
||||
pub mod rate_limit;
|
||||
pub mod waf;
|
||||
36
crates/api/src/middleware/rate_limit.rs
Normal file
36
crates/api/src/middleware/rate_limit.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use redis::AsyncCommands;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let api_key = req
|
||||
.headers()
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("anonymous");
|
||||
|
||||
let key = format!("rate_limit:{}", api_key);
|
||||
let mut conn = state.redis.clone();
|
||||
|
||||
let count: i64 = conn.incr(&key, 1).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if count == 1 {
|
||||
let _: () = conn.expire(&key, 60).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
if count > 60 {
|
||||
return Err(StatusCode::TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
51
crates/api/src/middleware/waf.rs
Normal file
51
crates/api/src/middleware/waf.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use axum::{
|
||||
extract::{ConnectInfo, Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use redis::AsyncCommands;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn ip_rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let ip = addr.ip().to_string();
|
||||
let key = format!("ip_rate_limit:{}", ip);
|
||||
let mut conn = state.redis.clone();
|
||||
|
||||
// Check if IP is blocked
|
||||
let blocked: bool = conn.exists(format!("ip_blocked:{}", ip))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if blocked {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
// Aggressive rate limiting: 100 req/min per IP
|
||||
let count: i64 = conn.incr(&key, 1)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if count == 1 {
|
||||
let _: () = conn.expire(&key, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
if count > 100 {
|
||||
// Block IP for 1 hour after exceeding limit
|
||||
let _: () = conn.set_ex(format!("ip_blocked:{}", ip), "1", 3600)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
return Err(StatusCode::TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
94
crates/api/src/queue.rs
Normal file
94
crates/api/src/queue.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use redis::AsyncCommands;
|
||||
use shared::{
|
||||
models::CrawlOptions,
|
||||
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn enqueue_job(
|
||||
state: &AppState,
|
||||
user_id: Uuid,
|
||||
api_key_id: Uuid,
|
||||
endpoint: &str,
|
||||
url: &str,
|
||||
options: &CrawlOptions,
|
||||
webhook_url: Option<String>,
|
||||
) -> Result<Uuid, redis::RedisError> {
|
||||
let job = Job {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
api_key_id,
|
||||
endpoint: endpoint.to_string(),
|
||||
url: url.to_string(),
|
||||
options: options.clone(),
|
||||
webhook_url,
|
||||
};
|
||||
|
||||
let job_json = serde_json::to_string(&job).unwrap();
|
||||
let mut conn = state.redis.clone();
|
||||
conn.rpush::<_, _, ()>(QUEUE_NAME, job_json).await?;
|
||||
|
||||
Ok(job.id)
|
||||
}
|
||||
|
||||
pub async fn wait_for_result(
|
||||
state: &AppState,
|
||||
job_id: Uuid,
|
||||
timeout_secs: u64,
|
||||
) -> Result<Option<JobResult>, redis::RedisError> {
|
||||
let result_key = format!("{}{}", RESULT_PREFIX, job_id);
|
||||
let mut conn = state.redis.clone();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
while start.elapsed().as_secs() < timeout_secs {
|
||||
let result_json: Option<String> = conn.get(&result_key).await?;
|
||||
if let Some(json) = result_json {
|
||||
let result: JobResult = serde_json::from_str(&json).unwrap_or_else(|_| JobResult {
|
||||
id: job_id,
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Failed to deserialize result".to_string()),
|
||||
duration_ms: 0,
|
||||
});
|
||||
return Ok(Some(result));
|
||||
}
|
||||
sleep(Duration::from_millis(200)).await;
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn get_cache_key(url: &str, endpoint: &str, options: &CrawlOptions) -> String {
|
||||
let opts_json = serde_json::to_string(options).unwrap_or_default();
|
||||
let hash = format!("{:x}", md5::compute(format!("{}:{}:{}", url, endpoint, opts_json)));
|
||||
format!("crawlapi:cache:{}", hash)
|
||||
}
|
||||
|
||||
pub async fn get_cached_result(
|
||||
state: &AppState,
|
||||
cache_key: &str,
|
||||
) -> Result<Option<JobResult>, redis::RedisError> {
|
||||
let mut conn = state.redis.clone();
|
||||
let result_json: Option<String> = conn.get::<_, Option<String>>(cache_key).await?;
|
||||
if let Some(json) = result_json {
|
||||
let result: JobResult = serde_json::from_str(&json).unwrap();
|
||||
return Ok(Some(result));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn set_cached_result(
|
||||
state: &AppState,
|
||||
cache_key: &str,
|
||||
result: &JobResult,
|
||||
ttl_secs: u64,
|
||||
) -> Result<(), redis::RedisError> {
|
||||
let mut conn = state.redis.clone();
|
||||
let json = serde_json::to_string(result).unwrap();
|
||||
conn.set_ex::<_, _, ()>(cache_key, json, ttl_secs).await?;
|
||||
Ok(())
|
||||
}
|
||||
130
crates/api/src/routes/ai.rs
Normal file
130
crates/api/src/routes/ai.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::CrawlRequest;
|
||||
|
||||
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AiExtractRequest {
|
||||
pub url: String,
|
||||
pub schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AiExtractResponse {
|
||||
pub success: bool,
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn extract(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<ApiKeyAuth>,
|
||||
Json(body): Json<AiExtractRequest>,
|
||||
) -> Result<Json<AiExtractResponse>, StatusCode> {
|
||||
let openai_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
|
||||
if openai_key.is_empty() {
|
||||
return Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("OpenAI not configured".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Crawl the page via queue
|
||||
let crawl_req = CrawlRequest {
|
||||
url: body.url.clone(),
|
||||
options: Default::default(),
|
||||
};
|
||||
|
||||
let job_id = queue::enqueue_job(
|
||||
&state,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
"crawl",
|
||||
&body.url,
|
||||
&crawl_req.options,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let job_result = queue::wait_for_result(&state, job_id, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
match job_result {
|
||||
Some(result) if result.success => {
|
||||
let data = result.data.unwrap_or_default();
|
||||
let html = data.get("html").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let title = data.get("title").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
// Call OpenAI
|
||||
let client = reqwest::Client::new();
|
||||
let system_prompt = format!(
|
||||
"You are a web scraping assistant. Extract structured data from the following HTML page titled '{}'. \
|
||||
Return ONLY a JSON object matching the requested schema. Do not include any explanation.",
|
||||
title
|
||||
);
|
||||
|
||||
let user_prompt = if let Some(p) = body.prompt {
|
||||
p
|
||||
} else {
|
||||
format!("Extract data from this HTML according to schema: {}\n\nHTML:\n{}",
|
||||
body.schema.to_string(),
|
||||
&html[..html.len().min(8000)])
|
||||
};
|
||||
|
||||
let res = client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {}", openai_key))
|
||||
.json(&serde_json::json!({
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{ "role": "system", "content": system_prompt },
|
||||
{ "role": "user", "content": user_prompt }
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"response_format": { "type": "json_object" }
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(response) => {
|
||||
let ai_data: serde_json::Value = response.json().await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
if let Some(content) = ai_data["choices"][0]["message"]["content"].as_str() {
|
||||
let parsed: serde_json::Value = serde_json::from_str(content).unwrap_or_else(|_| serde_json::json!({"raw": content}));
|
||||
Ok(Json(AiExtractResponse {
|
||||
success: true,
|
||||
data: Some(parsed),
|
||||
error: None,
|
||||
}))
|
||||
} else {
|
||||
Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Invalid OpenAI response".to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(format!("OpenAI error: {}", e)),
|
||||
})),
|
||||
}
|
||||
}
|
||||
_ => Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Failed to crawl page".to_string()),
|
||||
})),
|
||||
}
|
||||
}
|
||||
142
crates/api/src/routes/auth.rs
Normal file
142
crates/api/src/routes/auth.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use axum::{
|
||||
extract::{Json, Path, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::{api_keys, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::{ApiKey, User};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuthResponse {
|
||||
pub user: User,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<RegisterRequest>,
|
||||
) -> Result<Json<AuthResponse>, StatusCode> {
|
||||
if users::find_by_email(&state.db, &body.email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?.is_some() {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let password_hash = bcrypt::hash(&body.password, bcrypt::DEFAULT_COST).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let user = users::create(&state.db, &body.email, Some(&password_hash), None)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
|
||||
Ok(Json(AuthResponse { user, token }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<LoginRequest>,
|
||||
) -> Result<Json<AuthResponse>, StatusCode> {
|
||||
let user = users::find_by_email(&state.db, &body.email)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let password_hash = user.password_hash.as_ref().ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
if !bcrypt::verify(&body.password, password_hash).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
|
||||
Ok(Json(AuthResponse { user, token }))
|
||||
}
|
||||
|
||||
pub fn create_jwt(user_id: &str, secret: &str) -> Result<String, StatusCode> {
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
exp: usize,
|
||||
}
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
exp: (chrono::Utc::now() + chrono::Duration::days(30)).timestamp() as usize,
|
||||
};
|
||||
|
||||
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_bytes()))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateApiKeyRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ApiKeyResponse {
|
||||
pub id: Uuid,
|
||||
pub key: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
pub async fn create_api_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateApiKeyRequest>,
|
||||
) -> Result<Json<ApiKeyResponse>, StatusCode> {
|
||||
let api_key = format!("crawlapi_{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let key_hash = format!("{:x}", md5::compute(&api_key));
|
||||
|
||||
let key = api_keys::create(&state.db, auth.user_id, &key_hash, &body.name)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(ApiKeyResponse {
|
||||
id: key.id,
|
||||
key: api_key,
|
||||
name: key.name,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn list_api_keys(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
) -> Result<Json<Vec<ApiKey>>, StatusCode> {
|
||||
let keys = api_keys::list_by_user(&state.db, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(keys))
|
||||
}
|
||||
|
||||
pub async fn delete_api_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let deleted = api_keys::delete_by_id(&state.db, id, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if deleted {
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
} else {
|
||||
Err(StatusCode::NOT_FOUND)
|
||||
}
|
||||
}
|
||||
296
crates/api/src/routes/crawl.rs
Normal file
296
crates/api/src/routes/crawl.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::{usage_logs, users};
|
||||
use serde_json::json;
|
||||
use shared::{
|
||||
error::AppError,
|
||||
models::{CrawlRequest, CrawlResponse},
|
||||
};
|
||||
use std::time::Instant;
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState, storage, validation};
|
||||
|
||||
async fn upload_files_if_needed(
|
||||
state: &AppState,
|
||||
endpoint: &str,
|
||||
result: &mut serde_json::Value,
|
||||
) -> Result<(), AppError> {
|
||||
// Handle file_path
|
||||
let file_path_opt = result.get("file_path").and_then(|v| v.as_str()).map(String::from);
|
||||
if let Some(file_path) = file_path_opt {
|
||||
let file_data = fs::read(&file_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
let ext = if endpoint == "pdf" { "pdf" } else { "png" };
|
||||
let content_type = if endpoint == "pdf" { "application/pdf" } else { "image/png" };
|
||||
let key = storage::generate_file_key(endpoint, ext);
|
||||
|
||||
storage::upload_file(
|
||||
&state.s3,
|
||||
&state.config.s3_bucket,
|
||||
&key,
|
||||
content_type,
|
||||
file_data,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let public_url = storage::get_public_url(
|
||||
endpoint,
|
||||
&state.config.s3_endpoint,
|
||||
&state.config.s3_bucket,
|
||||
&key,
|
||||
);
|
||||
|
||||
if let Some(obj) = result.as_object_mut() {
|
||||
obj.remove("file_path");
|
||||
obj.insert("url".to_string(), json!(public_url));
|
||||
}
|
||||
|
||||
let _ = fs::remove_file(file_path).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_endpoint(
|
||||
state: State<AppState>,
|
||||
Extension(auth): Extension<ApiKeyAuth>,
|
||||
Json(body): Json<CrawlRequest>,
|
||||
endpoint: &'static str,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate URL
|
||||
if let Err(e) = validation::validate_url(&body.url) {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"error",
|
||||
0,
|
||||
start.elapsed().as_millis() as i64,
|
||||
)
|
||||
.await;
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(auth.user.credits),
|
||||
error: Some(e.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Validate webhook URL if provided
|
||||
if let Some(ref webhook_url) = body.options.webhook_url {
|
||||
if let Err(e) = validation::validate_webhook_url(webhook_url) {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(auth.user.credits),
|
||||
error: Some(e.to_string()),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Check credits
|
||||
if auth.user.credits <= 0 {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(0),
|
||||
error: Some("Insufficient credits".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Deduct credits
|
||||
let has_credits = users::deduct_credits(&state.db, auth.user.id, 1)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !has_credits {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(0),
|
||||
error: Some("Insufficient credits".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Try cache first
|
||||
let cache_key = queue::get_cache_key(&body.url, endpoint, &body.options).await;
|
||||
let cached = queue::get_cached_result(&state, &cache_key)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(cached_result) = cached {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"success",
|
||||
1,
|
||||
start.elapsed().as_millis() as i64,
|
||||
)
|
||||
.await;
|
||||
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: cached_result.success,
|
||||
data: cached_result.data,
|
||||
calls_remaining: Some(auth.user.credits - 1),
|
||||
error: cached_result.error,
|
||||
}));
|
||||
}
|
||||
|
||||
// Enqueue job and wait for result
|
||||
let job_id = queue::enqueue_job(
|
||||
&state,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
&body.options,
|
||||
body.options.webhook_url.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let job_result = queue::wait_for_result(&state, job_id, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let duration = start.elapsed().as_millis() as i64;
|
||||
let remaining = auth.user.credits - 1;
|
||||
|
||||
match job_result {
|
||||
Some(mut result) => {
|
||||
// Upload files if needed
|
||||
if let Some(ref mut data) = result.data {
|
||||
let _ = upload_files_if_needed(&state, endpoint, data).await;
|
||||
}
|
||||
|
||||
// Cache successful results for 5 minutes
|
||||
if result.success {
|
||||
let _ = queue::set_cached_result(&state, &cache_key, &result, 300).await;
|
||||
}
|
||||
|
||||
let status = if result.success { "success" } else { "error" };
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
status,
|
||||
1,
|
||||
duration,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(CrawlResponse {
|
||||
success: result.success,
|
||||
data: result.data,
|
||||
calls_remaining: Some(remaining),
|
||||
error: result.error,
|
||||
}))
|
||||
}
|
||||
None => {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"timeout",
|
||||
1,
|
||||
duration,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(remaining),
|
||||
error: Some("Job timed out".to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_crawl(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "crawl").await
|
||||
}
|
||||
|
||||
pub async fn handle_content(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "content").await
|
||||
}
|
||||
|
||||
pub async fn handle_screenshot(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "screenshot").await
|
||||
}
|
||||
|
||||
pub async fn handle_pdf(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "pdf").await
|
||||
}
|
||||
|
||||
pub async fn handle_markdown(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "markdown").await
|
||||
}
|
||||
|
||||
pub async fn handle_snapshot(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "snapshot").await
|
||||
}
|
||||
|
||||
pub async fn handle_scrape(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "scrape").await
|
||||
}
|
||||
|
||||
pub async fn handle_json(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "json").await
|
||||
}
|
||||
|
||||
pub async fn handle_links(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "links").await
|
||||
}
|
||||
72
crates/api/src/routes/mod.rs
Normal file
72
crates/api/src/routes/mod.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
pub mod auth;
|
||||
pub mod crawl;
|
||||
pub mod oauth;
|
||||
pub mod stripe;
|
||||
pub mod ai;
|
||||
pub mod teams;
|
||||
pub mod ws;
|
||||
|
||||
use axum::{
|
||||
middleware,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
use crate::{
|
||||
middleware::{auth::api_key_middleware, correlation::correlation_id_middleware, jwt::jwt_middleware, rate_limit::rate_limit_middleware, waf::ip_rate_limit_middleware},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
pub fn create_router(state: AppState) -> Router {
|
||||
let api_routes = Router::new()
|
||||
.route("/crawl", post(crawl::handle_crawl))
|
||||
.route("/content", post(crawl::handle_content))
|
||||
.route("/screenshot", post(crawl::handle_screenshot))
|
||||
.route("/pdf", post(crawl::handle_pdf))
|
||||
.route("/markdown", post(crawl::handle_markdown))
|
||||
.route("/snapshot", post(crawl::handle_snapshot))
|
||||
.route("/scrape", post(crawl::handle_scrape))
|
||||
.route("/json", post(crawl::handle_json))
|
||||
.route("/links", post(crawl::handle_links))
|
||||
.route("/extract", post(ai::extract))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), api_key_middleware))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), rate_limit_middleware));
|
||||
|
||||
let auth_routes = Router::new()
|
||||
.route("/auth/register", post(auth::register))
|
||||
.route("/auth/login", post(auth::login))
|
||||
.route("/auth/google", get(oauth::google_auth_url))
|
||||
.route("/auth/google/callback", get(oauth::google_callback));
|
||||
|
||||
let protected_routes = Router::new()
|
||||
.route("/auth/api-keys", post(auth::create_api_key))
|
||||
.route("/auth/api-keys", get(auth::list_api_keys))
|
||||
.route("/auth/api-keys/{id}", axum::routing::delete(auth::delete_api_key))
|
||||
.route("/stripe/checkout", post(stripe::create_checkout))
|
||||
.route("/teams", post(teams::create))
|
||||
.route("/teams/{slug}", get(teams::get))
|
||||
.route("/teams/{slug}/members", post(teams::add_member))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), jwt_middleware));
|
||||
|
||||
let stripe_webhook = Router::new()
|
||||
.route("/stripe/webhook", post(stripe::webhook));
|
||||
|
||||
Router::new()
|
||||
.nest("/api", api_routes)
|
||||
.nest("/api", auth_routes)
|
||||
.nest("/api", protected_routes)
|
||||
.nest("/api", stripe_webhook)
|
||||
.route("/metrics", get(|| async {
|
||||
use prometheus::Encoder;
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
let mut buffer = vec![];
|
||||
encoder.encode(&crate::metrics::REGISTRY.gather(), &mut buffer).unwrap();
|
||||
String::from_utf8(buffer).unwrap()
|
||||
}))
|
||||
.route("/ws/logs", get(ws::live_logs))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), ip_rate_limit_middleware))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), correlation_id_middleware))
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state)
|
||||
}
|
||||
137
crates/api/src/routes/oauth.rs
Normal file
137
crates/api/src/routes/oauth.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use db::repos::{oauth, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GoogleCallback {
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GoogleAuthUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleTokenResponse {
|
||||
access_token: String,
|
||||
id_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleUserInfo {
|
||||
sub: String,
|
||||
email: String,
|
||||
name: Option<String>,
|
||||
picture: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn google_auth_url(State(_state): State<AppState>) -> Result<Json<GoogleAuthUrl>, StatusCode> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
|
||||
if client_id.is_empty() {
|
||||
return Err(StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
|
||||
|
||||
let url = format!(
|
||||
"https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=email%20profile&access_type=offline&prompt=consent",
|
||||
client_id, redirect_uri
|
||||
);
|
||||
|
||||
Ok(Json(GoogleAuthUrl { url }))
|
||||
}
|
||||
|
||||
pub async fn google_callback(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<GoogleCallback>,
|
||||
) -> Result<Json<super::auth::AuthResponse>, StatusCode> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET").unwrap_or_default();
|
||||
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
|
||||
|
||||
if client_id.is_empty() || client_secret.is_empty() {
|
||||
// MVP fallback: create mock user
|
||||
let email = format!("google_user_{}@example.com", ¶ms.code[..8.min(params.code.len())]);
|
||||
let user = match users::find_by_email(&state.db, &email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
|
||||
Some(u) => u,
|
||||
None => {
|
||||
let u = users::create(&state.db, &email, None, Some(¶ms.code)).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let _ = oauth::create(&state.db, u.id, "google", ¶ms.code).await;
|
||||
u
|
||||
}
|
||||
};
|
||||
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
return Ok(Json(super::auth::AuthResponse { user, token }));
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
let client = reqwest::Client::new();
|
||||
let token_res = client
|
||||
.post("https://oauth2.googleapis.com/token")
|
||||
.form(&[
|
||||
("code", params.code.as_str()),
|
||||
("client_id", client_id.as_str()),
|
||||
("client_secret", client_secret.as_str()),
|
||||
("redirect_uri", redirect_uri.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.json::<GoogleTokenResponse>()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Get user info
|
||||
let user_info = client
|
||||
.get("https://openidconnect.googleapis.com/v1/userinfo")
|
||||
.header("Authorization", format!("Bearer {}", token_res.access_token))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.json::<GoogleUserInfo>()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Find or create user
|
||||
let user = match oauth::find_by_provider(&state.db, "google", &user_info.sub).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(oauth_account) => {
|
||||
users::find_by_id(&state.db, oauth_account.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
}
|
||||
None => {
|
||||
let user = match users::find_by_email(&state.db, &user_info.email).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(u) => u,
|
||||
None => {
|
||||
users::create(&state.db, &user_info.email, None, Some(&user_info.sub))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
}
|
||||
};
|
||||
|
||||
oauth::create(&state.db, user.id, "google", &user_info.sub)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
user
|
||||
}
|
||||
};
|
||||
|
||||
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
Ok(Json(super::auth::AuthResponse { user, token }))
|
||||
}
|
||||
146
crates/api/src/routes/stripe.rs
Normal file
146
crates/api/src/routes/stripe.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use db::repos::{subscriptions, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
use axum::Extension;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateCheckoutRequest {
|
||||
pub price_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CheckoutResponse {
|
||||
pub checkout_url: String,
|
||||
}
|
||||
|
||||
pub async fn create_checkout(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateCheckoutRequest>,
|
||||
) -> Result<Json<CheckoutResponse>, StatusCode> {
|
||||
let stripe_secret = std::env::var("STRIPE_SECRET_KEY").map_err(|_| StatusCode::NOT_IMPLEMENTED)?;
|
||||
|
||||
let user = users::find_by_id(&state.db, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Create Stripe customer via HTTP API directly (simpler than SDK for MVP)
|
||||
let client = reqwest::Client::new();
|
||||
let customer_res = client
|
||||
.post("https://api.stripe.com/v1/customers")
|
||||
.basic_auth(&stripe_secret, Some(""))
|
||||
.form(&[("email", &user.email)])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let customer_data: serde_json::Value = customer_res.json().await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let customer_id = customer_data["id"].as_str().unwrap_or("");
|
||||
|
||||
let success_url = std::env::var("STRIPE_SUCCESS_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/dashboard?success=true".to_string());
|
||||
let cancel_url = std::env::var("STRIPE_CANCEL_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/dashboard?canceled=true".to_string());
|
||||
|
||||
let session_res = client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&stripe_secret, Some(""))
|
||||
.form(&[
|
||||
("customer", customer_id),
|
||||
("success_url", &success_url),
|
||||
("cancel_url", &cancel_url),
|
||||
("mode", "subscription"),
|
||||
("line_items[0][price]", &body.price_id),
|
||||
("line_items[0][quantity]", "1"),
|
||||
("metadata[user_id]", &auth.user_id.to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let session_data: serde_json::Value = session_res.json().await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let url = session_data["url"].as_str().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(CheckoutResponse { checkout_url: url.to_string() }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StripeWebhook {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: String,
|
||||
pub data: StripeEventData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StripeEventData {
|
||||
pub object: serde_json::Value,
|
||||
}
|
||||
|
||||
pub async fn webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let stripe_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
|
||||
|
||||
// Verify webhook signature if configured
|
||||
if !stripe_secret.is_empty() {
|
||||
let sig = headers
|
||||
.get("stripe-signature")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// In production, verify signature using Stripe library
|
||||
// For MVP, we log and process
|
||||
tracing::info!("Webhook signature: {}", sig);
|
||||
}
|
||||
|
||||
let event: serde_json::Value = serde_json::from_str(&body).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"checkout.session.completed" => {
|
||||
if let Some(metadata) = event["data"]["object"]["metadata"].as_object() {
|
||||
if let Some(user_id_str) = metadata.get("user_id").and_then(|v| v.as_str()) {
|
||||
if let Ok(user_id) = uuid::Uuid::parse_str(user_id_str) {
|
||||
let customer_id = event["data"]["object"]["customer"].as_str().unwrap_or("");
|
||||
let subscription_id = event["data"]["object"]["subscription"].as_str().unwrap_or("");
|
||||
let _ = subscriptions::create_or_update(
|
||||
&state.db,
|
||||
user_id,
|
||||
Some(customer_id),
|
||||
Some(subscription_id),
|
||||
None,
|
||||
"active",
|
||||
"paid",
|
||||
).await;
|
||||
tracing::info!("Subscription activated for user {}", user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"invoice.payment_succeeded" => {
|
||||
tracing::info!("Invoice payment succeeded");
|
||||
}
|
||||
"customer.subscription.deleted" => {
|
||||
if let Some(sub_id) = event["data"]["object"]["id"].as_str() {
|
||||
let _ = subscriptions::update_status(&state.db, sub_id, "canceled").await;
|
||||
tracing::info!("Subscription {} canceled", sub_id);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
95
crates/api/src/routes/teams.rs
Normal file
95
crates/api/src/routes/teams.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use axum::{
|
||||
extract::{Json, Path, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::teams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::{Team, TeamMember};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateTeamRequest {
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TeamResponse {
|
||||
pub team: Team,
|
||||
pub members: Vec<TeamMember>,
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateTeamRequest>,
|
||||
) -> Result<Json<Team>, StatusCode> {
|
||||
let team = teams::create(&state.db, &body.name, &body.slug, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Add owner as member
|
||||
let _ = teams::add_member(&state.db, team.id, auth.user_id, "owner").await;
|
||||
|
||||
Ok(Json(team))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(slug): Path<String>,
|
||||
) -> Result<Json<TeamResponse>, StatusCode> {
|
||||
let team = teams::find_by_slug(&state.db, &slug)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Check membership
|
||||
let member = teams::find_member(&state.db, team.id, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if member.is_none() {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let members = teams::list_members(&state.db, team.id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(TeamResponse { team, members }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AddMemberRequest {
|
||||
pub user_id: Uuid,
|
||||
#[serde(default)]
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
pub async fn add_member(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(slug): Path<String>,
|
||||
Json(body): Json<AddMemberRequest>,
|
||||
) -> Result<Json<TeamMember>, StatusCode> {
|
||||
let team = teams::find_by_slug(&state.db, &slug)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Only owner can add members
|
||||
if team.owner_id != auth.user_id {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let role = if body.role.is_empty() { "member" } else { &body.role };
|
||||
let member = teams::add_member(&state.db, team.id, body.user_id, role)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(member))
|
||||
}
|
||||
41
crates/api/src/routes/ws.rs
Normal file
41
crates/api/src/routes/ws.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use axum::{
|
||||
extract::ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn live_logs(ws: WebSocketUpgrade) -> impl IntoResponse {
|
||||
ws.on_upgrade(handle_socket)
|
||||
}
|
||||
|
||||
/// Drives one live-logs WebSocket connection until it closes.
///
/// Sends a `connected` message, then emits a `heartbeat` JSON message with an
/// RFC 3339 UTC timestamp on every tick of a 5-second interval. A text frame
/// equal to `ping` is answered with `pong`. The loop ends when the peer sends
/// a close frame, the stream ends, or a heartbeat cannot be sent.
async fn handle_socket(mut socket: WebSocket) {
    // Send initial connection message
    let greeting = r#"{"type":"connected","message":"Live logs connected"}"#.to_string();
    let _ = socket.send(Message::Text(greeting)).await;

    // In a real implementation, this would subscribe to a Redis pub/sub channel
    // and stream logs to the client. For MVP, we send a heartbeat.
    let mut heartbeat = tokio::time::interval(std::time::Duration::from_secs(5));

    loop {
        tokio::select! {
            _ = heartbeat.tick() => {
                let payload = format!(
                    r#"{{"type":"heartbeat","timestamp":"{}"}}"#,
                    chrono::Utc::now().to_rfc3339()
                );
                if socket.send(Message::Text(payload)).await.is_err() {
                    break;
                }
            }
            incoming = socket.recv() => {
                match incoming {
                    None | Some(Ok(Message::Close(_))) => break,
                    Some(Ok(Message::Text(text))) if text == "ping" => {
                        let _ = socket.send(Message::Text("pong".to_string())).await;
                    }
                    _ => {}
                }
            }
        }
    }
}
|
||||
69
crates/api/src/secrets/mod.rs
Normal file
69
crates/api/src/secrets/mod.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use shared::error::AppError;
|
||||
|
||||
pub async fn get_secret(key: &str) -> Result<String, AppError> {
|
||||
// Priority 1: Environment variable (for local dev)
|
||||
if let Ok(val) = std::env::var(key) {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
// Priority 2: Try Vault
|
||||
if let Ok(val) = get_vault_secret(key).await {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
// Priority 3: Try AWS Secrets Manager
|
||||
if let Ok(val) = get_aws_secret(key).await {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
Err(AppError::Internal(format!("Secret {} not found in any provider", key)))
|
||||
}
|
||||
|
||||
async fn get_vault_secret(key: &str) -> Result<String, AppError> {
|
||||
let vault_addr = match std::env::var("VAULT_ADDR") {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => return Err(AppError::Internal("VAULT_ADDR not set".to_string())),
|
||||
};
|
||||
|
||||
let vault_token = match std::env::var("VAULT_TOKEN") {
|
||||
Ok(token) => token,
|
||||
Err(_) => return Err(AppError::Internal("VAULT_TOKEN not set".to_string())),
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.get(format!("{}/v1/secret/data/{}", vault_addr, key))
|
||||
.header("X-Vault-Token", vault_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Vault request failed: {}", e)))?;
|
||||
|
||||
let data: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Vault response parse failed: {}", e)))?;
|
||||
|
||||
data["data"]["data"][key]
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in Vault", key)))
|
||||
}
|
||||
|
||||
async fn get_aws_secret(key: &str) -> Result<String, AppError> {
|
||||
let secret_name = format!("crawlapi/{}", key.to_lowercase().replace('_', "/"));
|
||||
|
||||
let config = aws_config::from_env().load().await;
|
||||
let client = aws_sdk_secretsmanager::Client::new(&config);
|
||||
|
||||
let response = client
|
||||
.get_secret_value()
|
||||
.secret_id(&secret_name)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("AWS Secrets Manager error: {}", e)))?;
|
||||
|
||||
response
|
||||
.secret_string()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in AWS", key)))
|
||||
}
|
||||
13
crates/api/src/state.rs
Normal file
13
crates/api/src/state.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use db::DbPool;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use shared::config::AppConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub config: Arc<AppConfig>,
|
||||
pub db: DbPool,
|
||||
pub redis: MultiplexedConnection,
|
||||
pub s3: S3Client,
|
||||
}
|
||||
54
crates/api/src/storage/mod.rs
Normal file
54
crates/api/src/storage/mod.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use shared::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn upload_file(
|
||||
s3: &S3Client,
|
||||
bucket: &str,
|
||||
key: &str,
|
||||
content_type: &str,
|
||||
data: Vec<u8>,
|
||||
) -> Result<String, AppError> {
|
||||
s3.put_object()
|
||||
.bucket(bucket)
|
||||
.key(key)
|
||||
.content_type(content_type)
|
||||
.body(data.into())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::S3(e.to_string()))?;
|
||||
|
||||
Ok(format!("{}/{}", bucket, key))
|
||||
}
|
||||
|
||||
pub async fn ensure_bucket_exists(s3: &S3Client, bucket: &str) -> Result<(), AppError> {
|
||||
let exists = s3
|
||||
.head_bucket()
|
||||
.bucket(bucket)
|
||||
.send()
|
||||
.await
|
||||
.is_ok();
|
||||
|
||||
if !exists {
|
||||
s3.create_bucket()
|
||||
.bucket(bucket)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::S3(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn generate_file_key(endpoint: &str, ext: &str) -> String {
|
||||
let id = Uuid::new_v4().to_string().replace('-', "");
|
||||
match endpoint {
|
||||
"screenshot" => format!("screenshots/{}.png", id),
|
||||
"pdf" => format!("pdfs/{}.pdf", id),
|
||||
_ => format!("files/{}.{}", id, ext),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the URL of a stored object as `{s3_endpoint}/{bucket}/{key}`.
///
/// Trailing slashes on `s3_endpoint` are dropped so that an endpoint
/// configured as `http://host:9000/` does not yield `http://host:9000//...`.
/// `_endpoint` is accepted for interface compatibility and is not used.
pub fn get_public_url(_endpoint: &str, s3_endpoint: &str, bucket: &str, key: &str) -> String {
    let base = s3_endpoint.trim_end_matches('/');
    format!("{}/{}/{}", base, bucket, key)
}
|
||||
62
crates/api/src/validation.rs
Normal file
62
crates/api/src/validation.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use shared::error::AppError;
|
||||
|
||||
pub fn validate_url(url: &str) -> Result<(), AppError> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
|
||||
|
||||
// Only allow http and https
|
||||
if parsed.scheme() != "http" && parsed.scheme() != "https" {
|
||||
return Err(AppError::InvalidUrl("Only HTTP and HTTPS URLs are allowed".to_string()));
|
||||
}
|
||||
|
||||
// Block private IP ranges
|
||||
if let Some(host) = parsed.host_str() {
|
||||
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
|
||||
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
|
||||
}
|
||||
if host.starts_with("172.") {
|
||||
if let Some(seg) = host.split('.').nth(1) {
|
||||
if let Ok(n) = seg.parse::<u8>() {
|
||||
if n >= 16 && n <= 31 {
|
||||
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Block file://, ftp://, etc.
|
||||
if parsed.scheme() == "file" {
|
||||
return Err(AppError::InvalidUrl("File URLs are not allowed".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_webhook_url(url: &str) -> Result<(), AppError> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
|
||||
|
||||
// Only allow http and https
|
||||
if parsed.scheme() != "http" && parsed.scheme() != "https" {
|
||||
return Err(AppError::InvalidUrl("Webhook must use HTTP or HTTPS".to_string()));
|
||||
}
|
||||
|
||||
// Block private IPs and localhost for webhooks
|
||||
if let Some(host) = parsed.host_str() {
|
||||
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
|
||||
return Err(AppError::InvalidUrl("Webhook cannot point to private addresses".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_size(content: &[u8], max_mb: usize) -> Result<(), AppError> {
|
||||
let max_bytes = max_mb * 1024 * 1024;
|
||||
if content.len() > max_bytes {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Content exceeds maximum size of {}MB",
|
||||
max_mb
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
29
crates/api/tests/integration_test.rs
Normal file
29
crates/api/tests/integration_test.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_check() {
|
||||
// This is a placeholder for integration tests
|
||||
// In a real setup, you would spawn the API server and make HTTP requests
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crawl_request_validation() {
|
||||
let req = shared::models::CrawlRequest {
|
||||
url: "https://example.com".to_string(),
|
||||
options: shared::models::CrawlOptions::default(),
|
||||
};
|
||||
assert_eq!(req.url, "https://example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_response_format() {
|
||||
let response = shared::api::ApiResponse::ok(json!({"test": true}));
|
||||
assert!(response.success);
|
||||
assert!(response.data.is_some());
|
||||
assert!(response.error.is_none());
|
||||
|
||||
let error = shared::api::ApiResponse::<()>::err("Something went wrong");
|
||||
assert!(!error.success);
|
||||
assert!(error.error.is_some());
|
||||
}
|
||||
Reference in New Issue
Block a user