Initial commit: Full Crawl API implementation
Some checks failed
CI / Test (push) Has been cancelled
Deploy / Deploy to Staging (push) Has been cancelled
CI / Build & Push (push) Has been cancelled
Deploy / Deploy to Production (push) Has been cancelled

This commit is contained in:
2026-04-29 07:03:48 +00:00
commit 62994d4f3d
92 changed files with 6176 additions and 0 deletions

39
crates/api/Cargo.toml Normal file
View File

@@ -0,0 +1,39 @@
[package]
name = "api"
version = "0.1.0"
edition = "2021"
[dependencies]
shared = { path = "../shared" }
db = { path = "../db" }
axum = { workspace = true, features = ["ws"] }
tokio = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
redis = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
aws-config = { workspace = true }
aws-sdk-s3 = { workspace = true }
reqwest = { workspace = true }
jsonwebtoken = { workspace = true }
bcrypt = { workspace = true }
config = { workspace = true }
argon2 = { workspace = true }
url = { workspace = true }
sqlx = { workspace = true }
regex = { workspace = true }
scraper = { workspace = true }
markdown = { workspace = true }
md5 = "0.7"
prometheus = "0.13"
lazy_static = "1.5"
sentry = "0.36"
async-stripe = { version = "1.0.0-rc.5", features = ["default-tls"] }
aws-sdk-secretsmanager = "1.0"

View File

@@ -0,0 +1,51 @@
use db::connection::create_pool;
use db::repos::{api_keys, users};
use shared::config::AppConfig;
use std::sync::Arc;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.init();
let config = Arc::new(AppConfig::from_env()?);
let db = create_pool(&config.database_url).await?;
// Create test user
let email = "demo@crawlapi.dev";
let password = "demo123456";
let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
let user = match users::find_by_email(&db, email).await? {
Some(u) => {
tracing::info!("User already exists: {}", u.id);
u
}
None => {
let u = users::create(&db, email, Some(&password_hash), None).await?;
tracing::info!("Created user: {} with 30 free credits", u.id);
u
}
};
// Create API key
let api_key = format!("crawlapi_demo_{}", Uuid::new_v4().to_string().replace('-', ""));
let key_hash = format!("{:x}", md5::compute(&api_key));
let key = api_keys::create(&db, user.id, &key_hash, "Demo Key").await?;
tracing::info!("Created API key: {} (id: {})", api_key, key.id);
println!("\n========================================");
println!("SEED DATA CREATED");
println!("========================================");
println!("Email: {}", email);
println!("Password: {}", password);
println!("API Key: {}", api_key);
println!("Credits: {}", user.credits);
println!("========================================\n");
Ok(())
}

8
crates/api/src/lib.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod middleware;
pub mod metrics;
pub mod queue;
pub mod routes;
pub mod secrets;
pub mod state;
pub mod storage;
pub mod validation;

79
crates/api/src/main.rs Normal file
View File

@@ -0,0 +1,79 @@
use std::sync::Arc;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
use api::{metrics, routes, state::AppState, storage};
use db::connection::create_pool;
use shared::config::AppConfig;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let sentry_dsn = std::env::var("SENTRY_DSN").ok();
let _guard = sentry_dsn.map(|dsn| {
sentry::init((dsn, sentry::ClientOptions {
release: sentry::release_name!(),
..Default::default()
}))
});
// Structured JSON logging with correlation IDs
let json_logging = std::env::var("JSON_LOGGING").unwrap_or_else(|_| "false".to_string()) == "true";
if json_logging {
tracing_subscriber::registry()
.with(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
)
.with(
tracing_subscriber::fmt::layer()
.json()
.with_current_span(true)
.with_span_list(true)
.with_target(true),
)
.init();
} else {
tracing_subscriber::registry()
.with(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
}
metrics::register_metrics();
let config = Arc::new(AppConfig::from_env()?);
let db = create_pool(&config.database_url).await?;
sqlx::migrate!("../db/migrations").run(&db).await?;
let redis = redis::Client::open(config.redis_url.clone())?;
let redis_conn = redis.get_multiplexed_tokio_connection().await?;
let s3_config = aws_config::from_env()
.endpoint_url(&config.s3_endpoint)
.load()
.await;
let s3 = aws_sdk_s3::Client::new(&s3_config);
storage::ensure_bucket_exists(&s3, &config.s3_bucket).await?;
let state = AppState {
config,
db,
redis: redis_conn,
s3,
};
let app = routes::create_router(state)
.layer(TraceLayer::new_for_http());
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
tracing::info!("API server listening on {}", listener.local_addr()?);
axum::serve(listener, app).await?;
Ok(())
}

31
crates/api/src/metrics.rs Normal file
View File

@@ -0,0 +1,31 @@
use lazy_static::lazy_static;
use prometheus::{CounterVec, HistogramVec, Registry};
use std::time::Instant;
lazy_static! {
pub static ref REGISTRY: Registry = Registry::new();
pub static ref REQUEST_COUNTER: CounterVec = CounterVec::new(
prometheus::Opts::new("api_requests_total", "Total API requests"),
&["endpoint", "status"]
)
.unwrap();
pub static ref REQUEST_DURATION: HistogramVec = HistogramVec::new(
prometheus::HistogramOpts::new("api_request_duration_seconds", "Request duration in seconds"),
&["endpoint"]
)
.unwrap();
}
pub fn register_metrics() {
REGISTRY.register(Box::new(REQUEST_COUNTER.clone())).unwrap();
REGISTRY.register(Box::new(REQUEST_DURATION.clone())).unwrap();
}
pub fn record_request(endpoint: &str, status: &str) {
REQUEST_COUNTER.with_label_values(&[endpoint, status]).inc();
}
pub fn record_duration(endpoint: &str, start: Instant) {
let duration = start.elapsed().as_secs_f64();
REQUEST_DURATION.with_label_values(&[endpoint]).observe(duration);
}

View File

@@ -0,0 +1,52 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use db::repos::api_keys;
use shared::models::User;
use crate::state::AppState;
#[derive(Clone)]
pub struct ApiKeyAuth {
pub user: User,
pub api_key_id: uuid::Uuid,
}
pub async fn api_key_middleware(
State(state): State<AppState>,
mut req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let api_key = req
.headers()
.get("x-api-key")
.and_then(|v| v.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
let key_hash = format!("{:x}", md5::compute(api_key));
let api_key_record = api_keys::find_by_key_hash(&state.db, &key_hash)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
let user = db::repos::users::find_by_id(&state.db, api_key_record.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
api_keys::update_last_used(&state.db, api_key_record.id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let auth = ApiKeyAuth {
user,
api_key_id: api_key_record.id,
};
req.extensions_mut().insert(auth);
Ok(next.run(req).await)
}

View File

@@ -0,0 +1,42 @@
use axum::{
extract::{Request, State},
http::{header::HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use tracing::{info_span, Instrument};
use uuid::Uuid;
use crate::state::AppState;
pub async fn correlation_id_middleware(
State(_state): State<AppState>,
mut req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let correlation_id = req
.headers()
.get("x-correlation-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| Uuid::new_v4().to_string());
req.headers_mut().insert(
"x-correlation-id",
HeaderValue::from_str(&correlation_id).unwrap(),
);
let method = req.method().to_string();
let uri = req.uri().to_string();
let span = info_span!(
"http_request",
correlation_id = %correlation_id,
method = %method,
uri = %uri,
);
let response = next.run(req).instrument(span).await;
Ok(response)
}

View File

@@ -0,0 +1,49 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::state::AppState;
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtClaims {
pub sub: String,
pub exp: usize,
}
#[derive(Clone)]
pub struct JwtAuth {
pub user_id: Uuid,
}
pub async fn jwt_middleware(
State(state): State<AppState>,
mut req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let token = req
.headers()
.get("x-auth-token")
.or_else(|| req.headers().get("authorization"))
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ").or(Some(v)))
.ok_or(StatusCode::UNAUTHORIZED)?;
let validation = Validation::default();
let token_data = decode::<JwtClaims>(
token,
&DecodingKey::from_secret(state.config.jwt_secret.as_bytes()),
&validation,
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
let user_id = Uuid::parse_str(&token_data.claims.sub).map_err(|_| StatusCode::UNAUTHORIZED)?;
req.extensions_mut().insert(JwtAuth { user_id });
Ok(next.run(req).await)
}

View File

@@ -0,0 +1,5 @@
pub mod auth;
pub mod correlation;
pub mod jwt;
pub mod rate_limit;
pub mod waf;

View File

@@ -0,0 +1,36 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use redis::AsyncCommands;
use crate::state::AppState;
pub async fn rate_limit_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let api_key = req
.headers()
.get("x-api-key")
.and_then(|v| v.to_str().ok())
.unwrap_or("anonymous");
let key = format!("rate_limit:{}", api_key);
let mut conn = state.redis.clone();
let count: i64 = conn.incr(&key, 1).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if count == 1 {
let _: () = conn.expire(&key, 60).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
if count > 60 {
return Err(StatusCode::TOO_MANY_REQUESTS);
}
Ok(next.run(req).await)
}

View File

@@ -0,0 +1,51 @@
use axum::{
extract::{ConnectInfo, Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use redis::AsyncCommands;
use std::net::SocketAddr;
use crate::state::AppState;
pub async fn ip_rate_limit_middleware(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
req: Request,
next: Next,
) -> Result<Response, StatusCode> {
let ip = addr.ip().to_string();
let key = format!("ip_rate_limit:{}", ip);
let mut conn = state.redis.clone();
// Check if IP is blocked
let blocked: bool = conn.exists(format!("ip_blocked:{}", ip))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if blocked {
return Err(StatusCode::FORBIDDEN);
}
// Aggressive rate limiting: 100 req/min per IP
let count: i64 = conn.incr(&key, 1)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if count == 1 {
let _: () = conn.expire(&key, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
if count > 100 {
// Block IP for 1 hour after exceeding limit
let _: () = conn.set_ex(format!("ip_blocked:{}", ip), "1", 3600)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Err(StatusCode::TOO_MANY_REQUESTS);
}
Ok(next.run(req).await)
}

94
crates/api/src/queue.rs Normal file
View File

@@ -0,0 +1,94 @@
use redis::AsyncCommands;
use shared::{
models::CrawlOptions,
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
};
use std::time::Duration;
use tokio::time::sleep;
use uuid::Uuid;
use crate::state::AppState;
pub async fn enqueue_job(
state: &AppState,
user_id: Uuid,
api_key_id: Uuid,
endpoint: &str,
url: &str,
options: &CrawlOptions,
webhook_url: Option<String>,
) -> Result<Uuid, redis::RedisError> {
let job = Job {
id: Uuid::new_v4(),
user_id,
api_key_id,
endpoint: endpoint.to_string(),
url: url.to_string(),
options: options.clone(),
webhook_url,
};
let job_json = serde_json::to_string(&job).unwrap();
let mut conn = state.redis.clone();
conn.rpush::<_, _, ()>(QUEUE_NAME, job_json).await?;
Ok(job.id)
}
pub async fn wait_for_result(
state: &AppState,
job_id: Uuid,
timeout_secs: u64,
) -> Result<Option<JobResult>, redis::RedisError> {
let result_key = format!("{}{}", RESULT_PREFIX, job_id);
let mut conn = state.redis.clone();
let start = std::time::Instant::now();
while start.elapsed().as_secs() < timeout_secs {
let result_json: Option<String> = conn.get(&result_key).await?;
if let Some(json) = result_json {
let result: JobResult = serde_json::from_str(&json).unwrap_or_else(|_| JobResult {
id: job_id,
success: false,
data: None,
error: Some("Failed to deserialize result".to_string()),
duration_ms: 0,
});
return Ok(Some(result));
}
sleep(Duration::from_millis(200)).await;
}
Ok(None)
}
pub async fn get_cache_key(url: &str, endpoint: &str, options: &CrawlOptions) -> String {
let opts_json = serde_json::to_string(options).unwrap_or_default();
let hash = format!("{:x}", md5::compute(format!("{}:{}:{}", url, endpoint, opts_json)));
format!("crawlapi:cache:{}", hash)
}
pub async fn get_cached_result(
state: &AppState,
cache_key: &str,
) -> Result<Option<JobResult>, redis::RedisError> {
let mut conn = state.redis.clone();
let result_json: Option<String> = conn.get::<_, Option<String>>(cache_key).await?;
if let Some(json) = result_json {
let result: JobResult = serde_json::from_str(&json).unwrap();
return Ok(Some(result));
}
Ok(None)
}
pub async fn set_cached_result(
state: &AppState,
cache_key: &str,
result: &JobResult,
ttl_secs: u64,
) -> Result<(), redis::RedisError> {
let mut conn = state.redis.clone();
let json = serde_json::to_string(result).unwrap();
conn.set_ex::<_, _, ()>(cache_key, json, ttl_secs).await?;
Ok(())
}

130
crates/api/src/routes/ai.rs Normal file
View File

@@ -0,0 +1,130 @@
use axum::{
extract::{Json, State},
http::StatusCode,
Extension,
};
use serde::{Deserialize, Serialize};
use shared::models::CrawlRequest;
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState};
#[derive(Debug, Deserialize)]
pub struct AiExtractRequest {
pub url: String,
pub schema: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct AiExtractResponse {
pub success: bool,
pub data: Option<serde_json::Value>,
pub error: Option<String>,
}
pub async fn extract(
State(state): State<AppState>,
Extension(auth): Extension<ApiKeyAuth>,
Json(body): Json<AiExtractRequest>,
) -> Result<Json<AiExtractResponse>, StatusCode> {
let openai_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
if openai_key.is_empty() {
return Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("OpenAI not configured".to_string()),
}));
}
// Crawl the page via queue
let crawl_req = CrawlRequest {
url: body.url.clone(),
options: Default::default(),
};
let job_id = queue::enqueue_job(
&state,
auth.user.id,
auth.api_key_id,
"crawl",
&body.url,
&crawl_req.options,
None,
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let job_result = queue::wait_for_result(&state, job_id, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
match job_result {
Some(result) if result.success => {
let data = result.data.unwrap_or_default();
let html = data.get("html").and_then(|v| v.as_str()).unwrap_or("");
let title = data.get("title").and_then(|v| v.as_str()).unwrap_or("");
// Call OpenAI
let client = reqwest::Client::new();
let system_prompt = format!(
"You are a web scraping assistant. Extract structured data from the following HTML page titled '{}'. \
Return ONLY a JSON object matching the requested schema. Do not include any explanation.",
title
);
let user_prompt = if let Some(p) = body.prompt {
p
} else {
format!("Extract data from this HTML according to schema: {}\n\nHTML:\n{}",
body.schema.to_string(),
&html[..html.len().min(8000)])
};
let res = client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {}", openai_key))
.json(&serde_json::json!({
"model": "gpt-4o-mini",
"messages": [
{ "role": "system", "content": system_prompt },
{ "role": "user", "content": user_prompt }
],
"temperature": 0.1,
"response_format": { "type": "json_object" }
}))
.send()
.await;
match res {
Ok(response) => {
let ai_data: serde_json::Value = response.json().await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(content) = ai_data["choices"][0]["message"]["content"].as_str() {
let parsed: serde_json::Value = serde_json::from_str(content).unwrap_or_else(|_| serde_json::json!({"raw": content}));
Ok(Json(AiExtractResponse {
success: true,
data: Some(parsed),
error: None,
}))
} else {
Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("Invalid OpenAI response".to_string()),
}))
}
}
Err(e) => Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some(format!("OpenAI error: {}", e)),
})),
}
}
_ => Ok(Json(AiExtractResponse {
success: false,
data: None,
error: Some("Failed to crawl page".to_string()),
})),
}
}

View File

@@ -0,0 +1,142 @@
use axum::{
extract::{Json, Path, State},
http::StatusCode,
Extension,
};
use db::repos::{api_keys, users};
use serde::{Deserialize, Serialize};
use shared::models::{ApiKey, User};
use uuid::Uuid;
use crate::{middleware::jwt::JwtAuth, state::AppState};
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
pub email: String,
pub password: String,
}
#[derive(Debug, Serialize)]
pub struct AuthResponse {
pub user: User,
pub token: String,
}
pub async fn register(
State(state): State<AppState>,
Json(body): Json<RegisterRequest>,
) -> Result<Json<AuthResponse>, StatusCode> {
if users::find_by_email(&state.db, &body.email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?.is_some() {
return Err(StatusCode::CONFLICT);
}
let password_hash = bcrypt::hash(&body.password, bcrypt::DEFAULT_COST).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let user = users::create(&state.db, &body.email, Some(&password_hash), None)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
Ok(Json(AuthResponse { user, token }))
}
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
pub email: String,
pub password: String,
}
pub async fn login(
State(state): State<AppState>,
Json(body): Json<LoginRequest>,
) -> Result<Json<AuthResponse>, StatusCode> {
let user = users::find_by_email(&state.db, &body.email)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::UNAUTHORIZED)?;
let password_hash = user.password_hash.as_ref().ok_or(StatusCode::UNAUTHORIZED)?;
if !bcrypt::verify(&body.password, password_hash).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
return Err(StatusCode::UNAUTHORIZED);
}
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
Ok(Json(AuthResponse { user, token }))
}
pub fn create_jwt(user_id: &str, secret: &str) -> Result<String, StatusCode> {
use jsonwebtoken::{encode, EncodingKey, Header};
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String,
exp: usize,
}
let claims = Claims {
sub: user_id.to_string(),
exp: (chrono::Utc::now() + chrono::Duration::days(30)).timestamp() as usize,
};
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_bytes()))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
#[derive(Debug, Deserialize)]
pub struct CreateApiKeyRequest {
pub name: String,
}
#[derive(Debug, Serialize)]
pub struct ApiKeyResponse {
pub id: Uuid,
pub key: String,
pub name: String,
}
pub async fn create_api_key(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Json(body): Json<CreateApiKeyRequest>,
) -> Result<Json<ApiKeyResponse>, StatusCode> {
let api_key = format!("crawlapi_{}", Uuid::new_v4().to_string().replace('-', ""));
let key_hash = format!("{:x}", md5::compute(&api_key));
let key = api_keys::create(&state.db, auth.user_id, &key_hash, &body.name)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(ApiKeyResponse {
id: key.id,
key: api_key,
name: key.name,
}))
}
pub async fn list_api_keys(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
) -> Result<Json<Vec<ApiKey>>, StatusCode> {
let keys = api_keys::list_by_user(&state.db, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(keys))
}
pub async fn delete_api_key(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Path(id): Path<Uuid>,
) -> Result<StatusCode, StatusCode> {
let deleted = api_keys::delete_by_id(&state.db, id, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if deleted {
Ok(StatusCode::NO_CONTENT)
} else {
Err(StatusCode::NOT_FOUND)
}
}

View File

@@ -0,0 +1,296 @@
use axum::{
extract::{Json, State},
http::StatusCode,
Extension,
};
use db::repos::{usage_logs, users};
use serde_json::json;
use shared::{
error::AppError,
models::{CrawlRequest, CrawlResponse},
};
use std::time::Instant;
use tokio::fs;
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState, storage, validation};
async fn upload_files_if_needed(
state: &AppState,
endpoint: &str,
result: &mut serde_json::Value,
) -> Result<(), AppError> {
// Handle file_path
let file_path_opt = result.get("file_path").and_then(|v| v.as_str()).map(String::from);
if let Some(file_path) = file_path_opt {
let file_data = fs::read(&file_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to read file: {}", e)))?;
let ext = if endpoint == "pdf" { "pdf" } else { "png" };
let content_type = if endpoint == "pdf" { "application/pdf" } else { "image/png" };
let key = storage::generate_file_key(endpoint, ext);
storage::upload_file(
&state.s3,
&state.config.s3_bucket,
&key,
content_type,
file_data,
)
.await?;
let public_url = storage::get_public_url(
endpoint,
&state.config.s3_endpoint,
&state.config.s3_bucket,
&key,
);
if let Some(obj) = result.as_object_mut() {
obj.remove("file_path");
obj.insert("url".to_string(), json!(public_url));
}
let _ = fs::remove_file(file_path).await;
}
Ok(())
}
async fn handle_endpoint(
state: State<AppState>,
Extension(auth): Extension<ApiKeyAuth>,
Json(body): Json<CrawlRequest>,
endpoint: &'static str,
) -> Result<Json<CrawlResponse>, StatusCode> {
let start = Instant::now();
// Validate URL
if let Err(e) = validation::validate_url(&body.url) {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"error",
0,
start.elapsed().as_millis() as i64,
)
.await;
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(auth.user.credits),
error: Some(e.to_string()),
}));
}
// Validate webhook URL if provided
if let Some(ref webhook_url) = body.options.webhook_url {
if let Err(e) = validation::validate_webhook_url(webhook_url) {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(auth.user.credits),
error: Some(e.to_string()),
}));
}
}
// Check credits
if auth.user.credits <= 0 {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(0),
error: Some("Insufficient credits".to_string()),
}));
}
// Deduct credits
let has_credits = users::deduct_credits(&state.db, auth.user.id, 1)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !has_credits {
return Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(0),
error: Some("Insufficient credits".to_string()),
}));
}
// Try cache first
let cache_key = queue::get_cache_key(&body.url, endpoint, &body.options).await;
let cached = queue::get_cached_result(&state, &cache_key)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(cached_result) = cached {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"success",
1,
start.elapsed().as_millis() as i64,
)
.await;
return Ok(Json(CrawlResponse {
success: cached_result.success,
data: cached_result.data,
calls_remaining: Some(auth.user.credits - 1),
error: cached_result.error,
}));
}
// Enqueue job and wait for result
let job_id = queue::enqueue_job(
&state,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
&body.options,
body.options.webhook_url.clone(),
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let job_result = queue::wait_for_result(&state, job_id, 60)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let duration = start.elapsed().as_millis() as i64;
let remaining = auth.user.credits - 1;
match job_result {
Some(mut result) => {
// Upload files if needed
if let Some(ref mut data) = result.data {
let _ = upload_files_if_needed(&state, endpoint, data).await;
}
// Cache successful results for 5 minutes
if result.success {
let _ = queue::set_cached_result(&state, &cache_key, &result, 300).await;
}
let status = if result.success { "success" } else { "error" };
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
status,
1,
duration,
)
.await;
Ok(Json(CrawlResponse {
success: result.success,
data: result.data,
calls_remaining: Some(remaining),
error: result.error,
}))
}
None => {
let _ = usage_logs::create(
&state.db,
auth.user.id,
auth.api_key_id,
endpoint,
&body.url,
"timeout",
1,
duration,
)
.await;
Ok(Json(CrawlResponse {
success: false,
data: None,
calls_remaining: Some(remaining),
error: Some("Job timed out".to_string()),
}))
}
}
}
pub async fn handle_crawl(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "crawl").await
}
pub async fn handle_content(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "content").await
}
pub async fn handle_screenshot(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "screenshot").await
}
pub async fn handle_pdf(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "pdf").await
}
pub async fn handle_markdown(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "markdown").await
}
pub async fn handle_snapshot(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "snapshot").await
}
pub async fn handle_scrape(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "scrape").await
}
pub async fn handle_json(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "json").await
}
pub async fn handle_links(
state: State<AppState>,
auth: Extension<ApiKeyAuth>,
body: Json<CrawlRequest>,
) -> Result<Json<CrawlResponse>, StatusCode> {
handle_endpoint(state, auth, body, "links").await
}

View File

@@ -0,0 +1,72 @@
pub mod auth;
pub mod crawl;
pub mod oauth;
pub mod stripe;
pub mod ai;
pub mod teams;
pub mod ws;
use axum::{
middleware,
routing::{get, post},
Router,
};
use tower_http::cors::CorsLayer;
use crate::{
middleware::{auth::api_key_middleware, correlation::correlation_id_middleware, jwt::jwt_middleware, rate_limit::rate_limit_middleware, waf::ip_rate_limit_middleware},
state::AppState,
};
pub fn create_router(state: AppState) -> Router {
let api_routes = Router::new()
.route("/crawl", post(crawl::handle_crawl))
.route("/content", post(crawl::handle_content))
.route("/screenshot", post(crawl::handle_screenshot))
.route("/pdf", post(crawl::handle_pdf))
.route("/markdown", post(crawl::handle_markdown))
.route("/snapshot", post(crawl::handle_snapshot))
.route("/scrape", post(crawl::handle_scrape))
.route("/json", post(crawl::handle_json))
.route("/links", post(crawl::handle_links))
.route("/extract", post(ai::extract))
.route_layer(middleware::from_fn_with_state(state.clone(), api_key_middleware))
.route_layer(middleware::from_fn_with_state(state.clone(), rate_limit_middleware));
let auth_routes = Router::new()
.route("/auth/register", post(auth::register))
.route("/auth/login", post(auth::login))
.route("/auth/google", get(oauth::google_auth_url))
.route("/auth/google/callback", get(oauth::google_callback));
let protected_routes = Router::new()
.route("/auth/api-keys", post(auth::create_api_key))
.route("/auth/api-keys", get(auth::list_api_keys))
.route("/auth/api-keys/{id}", axum::routing::delete(auth::delete_api_key))
.route("/stripe/checkout", post(stripe::create_checkout))
.route("/teams", post(teams::create))
.route("/teams/{slug}", get(teams::get))
.route("/teams/{slug}/members", post(teams::add_member))
.route_layer(middleware::from_fn_with_state(state.clone(), jwt_middleware));
let stripe_webhook = Router::new()
.route("/stripe/webhook", post(stripe::webhook));
Router::new()
.nest("/api", api_routes)
.nest("/api", auth_routes)
.nest("/api", protected_routes)
.nest("/api", stripe_webhook)
.route("/metrics", get(|| async {
use prometheus::Encoder;
let encoder = prometheus::TextEncoder::new();
let mut buffer = vec![];
encoder.encode(&crate::metrics::REGISTRY.gather(), &mut buffer).unwrap();
String::from_utf8(buffer).unwrap()
}))
.route("/ws/logs", get(ws::live_logs))
.layer(middleware::from_fn_with_state(state.clone(), ip_rate_limit_middleware))
.layer(middleware::from_fn_with_state(state.clone(), correlation_id_middleware))
.layer(CorsLayer::permissive())
.with_state(state)
}

View File

@@ -0,0 +1,137 @@
use axum::{
extract::{Query, State},
http::StatusCode,
Json,
};
use db::repos::{oauth, users};
use serde::{Deserialize, Serialize};
use crate::state::AppState;
#[derive(Debug, Deserialize)]
pub struct GoogleCallback {
pub code: String,
}
#[derive(Debug, Serialize)]
pub struct GoogleAuthUrl {
pub url: String,
}
#[derive(Debug, Deserialize)]
struct GoogleTokenResponse {
access_token: String,
id_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct GoogleUserInfo {
sub: String,
email: String,
name: Option<String>,
picture: Option<String>,
}
pub async fn google_auth_url(State(_state): State<AppState>) -> Result<Json<GoogleAuthUrl>, StatusCode> {
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
if client_id.is_empty() {
return Err(StatusCode::NOT_IMPLEMENTED);
}
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
let url = format!(
"https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=email%20profile&access_type=offline&prompt=consent",
client_id, redirect_uri
);
Ok(Json(GoogleAuthUrl { url }))
}
pub async fn google_callback(
State(state): State<AppState>,
Query(params): Query<GoogleCallback>,
) -> Result<Json<super::auth::AuthResponse>, StatusCode> {
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET").unwrap_or_default();
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
if client_id.is_empty() || client_secret.is_empty() {
// MVP fallback: create mock user
let email = format!("google_user_{}@example.com", &params.code[..8.min(params.code.len())]);
let user = match users::find_by_email(&state.db, &email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
Some(u) => u,
None => {
let u = users::create(&state.db, &email, None, Some(&params.code)).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let _ = oauth::create(&state.db, u.id, "google", &params.code).await;
u
}
};
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
return Ok(Json(super::auth::AuthResponse { user, token }));
}
// Exchange code for token
let client = reqwest::Client::new();
let token_res = client
.post("https://oauth2.googleapis.com/token")
.form(&[
("code", params.code.as_str()),
("client_id", client_id.as_str()),
("client_secret", client_secret.as_str()),
("redirect_uri", redirect_uri.as_str()),
("grant_type", "authorization_code"),
])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.json::<GoogleTokenResponse>()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Get user info
let user_info = client
.get("https://openidconnect.googleapis.com/v1/userinfo")
.header("Authorization", format!("Bearer {}", token_res.access_token))
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.json::<GoogleUserInfo>()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Find or create user
let user = match oauth::find_by_provider(&state.db, "google", &user_info.sub).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
Some(oauth_account) => {
users::find_by_id(&state.db, oauth_account.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
}
None => {
let user = match users::find_by_email(&state.db, &user_info.email).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
Some(u) => u,
None => {
users::create(&state.db, &user_info.email, None, Some(&user_info.sub))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
}
};
oauth::create(&state.db, user.id, "google", &user_info.sub)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
user
}
};
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
Ok(Json(super::auth::AuthResponse { user, token }))
}

View File

@@ -0,0 +1,146 @@
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use db::repos::{subscriptions, users};
use serde::{Deserialize, Serialize};
use crate::{middleware::jwt::JwtAuth, state::AppState};
use axum::Extension;
#[derive(Debug, Deserialize)]
pub struct CreateCheckoutRequest {
pub price_id: String,
}
#[derive(Debug, Serialize)]
pub struct CheckoutResponse {
pub checkout_url: String,
}
pub async fn create_checkout(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Json(body): Json<CreateCheckoutRequest>,
) -> Result<Json<CheckoutResponse>, StatusCode> {
let stripe_secret = std::env::var("STRIPE_SECRET_KEY").map_err(|_| StatusCode::NOT_IMPLEMENTED)?;
let user = users::find_by_id(&state.db, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Create Stripe customer via HTTP API directly (simpler than SDK for MVP)
let client = reqwest::Client::new();
let customer_res = client
.post("https://api.stripe.com/v1/customers")
.basic_auth(&stripe_secret, Some(""))
.form(&[("email", &user.email)])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let customer_data: serde_json::Value = customer_res.json().await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let customer_id = customer_data["id"].as_str().unwrap_or("");
let success_url = std::env::var("STRIPE_SUCCESS_URL")
.unwrap_or_else(|_| "http://localhost:3000/dashboard?success=true".to_string());
let cancel_url = std::env::var("STRIPE_CANCEL_URL")
.unwrap_or_else(|_| "http://localhost:3000/dashboard?canceled=true".to_string());
let session_res = client
.post("https://api.stripe.com/v1/checkout/sessions")
.basic_auth(&stripe_secret, Some(""))
.form(&[
("customer", customer_id),
("success_url", &success_url),
("cancel_url", &cancel_url),
("mode", "subscription"),
("line_items[0][price]", &body.price_id),
("line_items[0][quantity]", "1"),
("metadata[user_id]", &auth.user_id.to_string()),
])
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let session_data: serde_json::Value = session_res.json().await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let url = session_data["url"].as_str().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(CheckoutResponse { checkout_url: url.to_string() }))
}
#[derive(Debug, Deserialize)]
pub struct StripeWebhook {
#[serde(rename = "type")]
pub event_type: String,
pub data: StripeEventData,
}
#[derive(Debug, Deserialize)]
pub struct StripeEventData {
pub object: serde_json::Value,
}
pub async fn webhook(
State(state): State<AppState>,
headers: HeaderMap,
body: String,
) -> Result<StatusCode, StatusCode> {
let stripe_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
// Verify webhook signature if configured
if !stripe_secret.is_empty() {
let sig = headers
.get("stripe-signature")
.and_then(|v| v.to_str().ok())
.ok_or(StatusCode::BAD_REQUEST)?;
// In production, verify signature using Stripe library
// For MVP, we log and process
tracing::info!("Webhook signature: {}", sig);
}
let event: serde_json::Value = serde_json::from_str(&body).map_err(|_| StatusCode::BAD_REQUEST)?;
let event_type = event["type"].as_str().unwrap_or("");
match event_type {
"checkout.session.completed" => {
if let Some(metadata) = event["data"]["object"]["metadata"].as_object() {
if let Some(user_id_str) = metadata.get("user_id").and_then(|v| v.as_str()) {
if let Ok(user_id) = uuid::Uuid::parse_str(user_id_str) {
let customer_id = event["data"]["object"]["customer"].as_str().unwrap_or("");
let subscription_id = event["data"]["object"]["subscription"].as_str().unwrap_or("");
let _ = subscriptions::create_or_update(
&state.db,
user_id,
Some(customer_id),
Some(subscription_id),
None,
"active",
"paid",
).await;
tracing::info!("Subscription activated for user {}", user_id);
}
}
}
}
"invoice.payment_succeeded" => {
tracing::info!("Invoice payment succeeded");
}
"customer.subscription.deleted" => {
if let Some(sub_id) = event["data"]["object"]["id"].as_str() {
let _ = subscriptions::update_status(&state.db, sub_id, "canceled").await;
tracing::info!("Subscription {} canceled", sub_id);
}
}
_ => {}
}
Ok(StatusCode::OK)
}

View File

@@ -0,0 +1,95 @@
use axum::{
extract::{Json, Path, State},
http::StatusCode,
Extension,
};
use db::repos::teams;
use serde::{Deserialize, Serialize};
use shared::models::{Team, TeamMember};
use uuid::Uuid;
use crate::{middleware::jwt::JwtAuth, state::AppState};
#[derive(Debug, Deserialize)]
pub struct CreateTeamRequest {
pub name: String,
pub slug: String,
}
#[derive(Debug, Serialize)]
pub struct TeamResponse {
pub team: Team,
pub members: Vec<TeamMember>,
}
pub async fn create(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Json(body): Json<CreateTeamRequest>,
) -> Result<Json<Team>, StatusCode> {
let team = teams::create(&state.db, &body.name, &body.slug, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Add owner as member
let _ = teams::add_member(&state.db, team.id, auth.user_id, "owner").await;
Ok(Json(team))
}
pub async fn get(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Path(slug): Path<String>,
) -> Result<Json<TeamResponse>, StatusCode> {
let team = teams::find_by_slug(&state.db, &slug)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Check membership
let member = teams::find_member(&state.db, team.id, auth.user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if member.is_none() {
return Err(StatusCode::FORBIDDEN);
}
let members = teams::list_members(&state.db, team.id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(TeamResponse { team, members }))
}
#[derive(Debug, Deserialize)]
pub struct AddMemberRequest {
pub user_id: Uuid,
#[serde(default)]
pub role: String,
}
pub async fn add_member(
State(state): State<AppState>,
Extension(auth): Extension<JwtAuth>,
Path(slug): Path<String>,
Json(body): Json<AddMemberRequest>,
) -> Result<Json<TeamMember>, StatusCode> {
let team = teams::find_by_slug(&state.db, &slug)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Only owner can add members
if team.owner_id != auth.user_id {
return Err(StatusCode::FORBIDDEN);
}
let role = if body.role.is_empty() { "member" } else { &body.role };
let member = teams::add_member(&state.db, team.id, body.user_id, role)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(member))
}

View File

@@ -0,0 +1,41 @@
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
response::IntoResponse,
};
pub async fn live_logs(ws: WebSocketUpgrade) -> impl IntoResponse {
ws.on_upgrade(handle_socket)
}
async fn handle_socket(mut socket: WebSocket) {
// Send initial connection message
let _ = socket.send(Message::Text(r#"{"type":"connected","message":"Live logs connected"}"#.to_string())).await;
// In a real implementation, this would subscribe to a Redis pub/sub channel
// and stream logs to the client. For MVP, we send a heartbeat.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
loop {
tokio::select! {
_ = interval.tick() => {
let msg = r#"{"type":"heartbeat","timestamp":""#.to_string()
+ &chrono::Utc::now().to_rfc3339()
+ "\"}";
if socket.send(Message::Text(msg)).await.is_err() {
break;
}
}
msg = socket.recv() => {
match msg {
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Text(text))) => {
if text == "ping" {
let _ = socket.send(Message::Text("pong".to_string())).await;
}
}
_ => {}
}
}
}
}
}

View File

@@ -0,0 +1,69 @@
use shared::error::AppError;
pub async fn get_secret(key: &str) -> Result<String, AppError> {
// Priority 1: Environment variable (for local dev)
if let Ok(val) = std::env::var(key) {
return Ok(val);
}
// Priority 2: Try Vault
if let Ok(val) = get_vault_secret(key).await {
return Ok(val);
}
// Priority 3: Try AWS Secrets Manager
if let Ok(val) = get_aws_secret(key).await {
return Ok(val);
}
Err(AppError::Internal(format!("Secret {} not found in any provider", key)))
}
async fn get_vault_secret(key: &str) -> Result<String, AppError> {
let vault_addr = match std::env::var("VAULT_ADDR") {
Ok(addr) => addr,
Err(_) => return Err(AppError::Internal("VAULT_ADDR not set".to_string())),
};
let vault_token = match std::env::var("VAULT_TOKEN") {
Ok(token) => token,
Err(_) => return Err(AppError::Internal("VAULT_TOKEN not set".to_string())),
};
let client = reqwest::Client::new();
let response = client
.get(format!("{}/v1/secret/data/{}", vault_addr, key))
.header("X-Vault-Token", vault_token)
.send()
.await
.map_err(|e| AppError::Internal(format!("Vault request failed: {}", e)))?;
let data: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::Internal(format!("Vault response parse failed: {}", e)))?;
data["data"]["data"][key]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in Vault", key)))
}
async fn get_aws_secret(key: &str) -> Result<String, AppError> {
let secret_name = format!("crawlapi/{}", key.to_lowercase().replace('_', "/"));
let config = aws_config::from_env().load().await;
let client = aws_sdk_secretsmanager::Client::new(&config);
let response = client
.get_secret_value()
.secret_id(&secret_name)
.send()
.await
.map_err(|e| AppError::Internal(format!("AWS Secrets Manager error: {}", e)))?;
response
.secret_string()
.map(|s| s.to_string())
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in AWS", key)))
}

13
crates/api/src/state.rs Normal file
View File

@@ -0,0 +1,13 @@
use aws_sdk_s3::Client as S3Client;
use db::DbPool;
use redis::aio::MultiplexedConnection;
use shared::config::AppConfig;
use std::sync::Arc;
#[derive(Clone)]
pub struct AppState {
pub config: Arc<AppConfig>,
pub db: DbPool,
pub redis: MultiplexedConnection,
pub s3: S3Client,
}

View File

@@ -0,0 +1,54 @@
use aws_sdk_s3::Client as S3Client;
use shared::error::AppError;
use uuid::Uuid;
pub async fn upload_file(
s3: &S3Client,
bucket: &str,
key: &str,
content_type: &str,
data: Vec<u8>,
) -> Result<String, AppError> {
s3.put_object()
.bucket(bucket)
.key(key)
.content_type(content_type)
.body(data.into())
.send()
.await
.map_err(|e| AppError::S3(e.to_string()))?;
Ok(format!("{}/{}", bucket, key))
}
pub async fn ensure_bucket_exists(s3: &S3Client, bucket: &str) -> Result<(), AppError> {
let exists = s3
.head_bucket()
.bucket(bucket)
.send()
.await
.is_ok();
if !exists {
s3.create_bucket()
.bucket(bucket)
.send()
.await
.map_err(|e| AppError::S3(e.to_string()))?;
}
Ok(())
}
pub fn generate_file_key(endpoint: &str, ext: &str) -> String {
let id = Uuid::new_v4().to_string().replace('-', "");
match endpoint {
"screenshot" => format!("screenshots/{}.png", id),
"pdf" => format!("pdfs/{}.pdf", id),
_ => format!("files/{}.{}", id, ext),
}
}
pub fn get_public_url(_endpoint: &str, s3_endpoint: &str, bucket: &str, key: &str) -> String {
format!("{}/{}/{}", s3_endpoint, bucket, key)
}

View File

@@ -0,0 +1,62 @@
use shared::error::AppError;
pub fn validate_url(url: &str) -> Result<(), AppError> {
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
// Only allow http and https
if parsed.scheme() != "http" && parsed.scheme() != "https" {
return Err(AppError::InvalidUrl("Only HTTP and HTTPS URLs are allowed".to_string()));
}
// Block private IP ranges
if let Some(host) = parsed.host_str() {
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
}
if host.starts_with("172.") {
if let Some(seg) = host.split('.').nth(1) {
if let Ok(n) = seg.parse::<u8>() {
if n >= 16 && n <= 31 {
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
}
}
}
}
}
// Block file://, ftp://, etc.
if parsed.scheme() == "file" {
return Err(AppError::InvalidUrl("File URLs are not allowed".to_string()));
}
Ok(())
}
pub fn validate_webhook_url(url: &str) -> Result<(), AppError> {
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
// Only allow http and https
if parsed.scheme() != "http" && parsed.scheme() != "https" {
return Err(AppError::InvalidUrl("Webhook must use HTTP or HTTPS".to_string()));
}
// Block private IPs and localhost for webhooks
if let Some(host) = parsed.host_str() {
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
return Err(AppError::InvalidUrl("Webhook cannot point to private addresses".to_string()));
}
}
Ok(())
}
pub fn validate_size(content: &[u8], max_mb: usize) -> Result<(), AppError> {
let max_bytes = max_mb * 1024 * 1024;
if content.len() > max_bytes {
return Err(AppError::BadRequest(format!(
"Content exceeds maximum size of {}MB",
max_mb
)));
}
Ok(())
}

View File

@@ -0,0 +1,29 @@
use serde_json::json;
#[tokio::test]
async fn test_health_check() {
// This is a placeholder for integration tests
// In a real setup, you would spawn the API server and make HTTP requests
assert!(true);
}
#[tokio::test]
async fn test_crawl_request_validation() {
let req = shared::models::CrawlRequest {
url: "https://example.com".to_string(),
options: shared::models::CrawlOptions::default(),
};
assert_eq!(req.url, "https://example.com");
}
#[tokio::test]
async fn test_api_response_format() {
let response = shared::api::ApiResponse::ok(json!({"test": true}));
assert!(response.success);
assert!(response.data.is_some());
assert!(response.error.is_none());
let error = shared::api::ApiResponse::<()>::err("Something went wrong");
assert!(!error.success);
assert!(error.error.is_some());
}

17
crates/db/Cargo.toml Normal file
View File

@@ -0,0 +1,17 @@
[package]
name = "db"
version = "0.1.0"
edition = "2021"
[dependencies]
shared = { path = "../shared" }
sqlx = { workspace = true }
tokio = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
serde_json = { workspace = true }

View File

@@ -0,0 +1,38 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE users (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
email VARCHAR(255) UNIQUE NOT NULL,
password_hash VARCHAR(255),
google_id VARCHAR(255) UNIQUE,
credits BIGINT NOT NULL DEFAULT 30,
tier VARCHAR(50) NOT NULL DEFAULT 'free',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE api_keys (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_hash VARCHAR(255) UNIQUE NOT NULL,
name VARCHAR(255) NOT NULL DEFAULT 'Default',
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE usage_logs (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
api_key_id UUID NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
endpoint VARCHAR(100) NOT NULL,
url TEXT NOT NULL,
status VARCHAR(50) NOT NULL,
credits_used BIGINT NOT NULL DEFAULT 1,
duration_ms BIGINT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_api_keys_user_id ON api_keys(user_id);
CREATE INDEX idx_api_keys_key_hash ON api_keys(key_hash);
CREATE INDEX idx_usage_logs_user_id ON usage_logs(user_id);
CREATE INDEX idx_usage_logs_created_at ON usage_logs(created_at);

View File

@@ -0,0 +1,27 @@
CREATE TABLE oauth_accounts (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider VARCHAR(50) NOT NULL,
provider_account_id VARCHAR(255) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(provider, provider_account_id)
);
CREATE INDEX idx_oauth_accounts_user_id ON oauth_accounts(user_id);
CREATE TABLE subscriptions (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
stripe_customer_id VARCHAR(255),
stripe_subscription_id VARCHAR(255),
stripe_price_id VARCHAR(255),
status VARCHAR(50) NOT NULL DEFAULT 'incomplete',
tier VARCHAR(50) NOT NULL DEFAULT 'free',
current_period_start TIMESTAMPTZ,
current_period_end TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_subscriptions_user_id ON subscriptions(user_id);
CREATE INDEX idx_subscriptions_stripe_customer ON subscriptions(stripe_customer_id);

View File

@@ -0,0 +1,21 @@
CREATE TABLE teams (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
name VARCHAR(255) NOT NULL,
slug VARCHAR(255) UNIQUE NOT NULL,
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE team_members (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
team_id UUID NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role VARCHAR(50) NOT NULL DEFAULT 'member',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(team_id, user_id)
);
CREATE INDEX idx_teams_owner ON teams(owner_id);
CREATE INDEX idx_team_members_team ON team_members(team_id);
CREATE INDEX idx_team_members_user ON team_members(user_id);

View File

@@ -0,0 +1,7 @@
use sqlx::PgPool;
pub type DbPool = PgPool;
pub async fn create_pool(database_url: &str) -> Result<DbPool, sqlx::Error> {
PgPool::connect(database_url).await
}

4
crates/db/src/lib.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod connection;
pub mod repos;
pub use connection::DbPool;

View File

@@ -0,0 +1,64 @@
use shared::models::ApiKey;
use sqlx::PgPool;
use uuid::Uuid;
pub async fn find_by_key_hash(pool: &PgPool, key_hash: &str) -> Result<Option<ApiKey>, sqlx::Error> {
sqlx::query_as::<_, ApiKey>(
r#"SELECT id, user_id, key_hash, name, last_used_at, created_at FROM api_keys WHERE key_hash = $1"#,
)
.bind(key_hash)
.fetch_optional(pool)
.await
}
pub async fn create(
pool: &PgPool,
user_id: Uuid,
key_hash: &str,
name: &str,
) -> Result<ApiKey, sqlx::Error> {
sqlx::query_as::<_, ApiKey>(
r#"INSERT INTO api_keys (id, user_id, key_hash, name)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, key_hash, name, last_used_at, created_at"#,
)
.bind(Uuid::new_v4())
.bind(user_id)
.bind(key_hash)
.bind(name)
.fetch_one(pool)
.await
}
pub async fn list_by_user(pool: &PgPool, user_id: Uuid) -> Result<Vec<ApiKey>, sqlx::Error> {
sqlx::query_as::<_, ApiKey>(
r#"SELECT id, user_id, key_hash, name, last_used_at, created_at FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC"#,
)
.bind(user_id)
.fetch_all(pool)
.await
}
pub async fn update_last_used(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
sqlx::query(
r#"UPDATE api_keys SET last_used_at = $1 WHERE id = $2"#,
)
.bind(chrono::Utc::now())
.bind(id)
.execute(pool)
.await?;
Ok(())
}
pub async fn delete_by_id(pool: &PgPool, id: Uuid, user_id: Uuid) -> Result<bool, sqlx::Error> {
let result = sqlx::query(
r#"DELETE FROM api_keys WHERE id = $1 AND user_id = $2"#,
)
.bind(id)
.bind(user_id)
.execute(pool)
.await?;
Ok(result.rows_affected() > 0)
}

View File

@@ -0,0 +1,6 @@
pub mod api_keys;
pub mod oauth;
pub mod subscriptions;
pub mod teams;
pub mod usage_logs;
pub mod users;

View File

@@ -0,0 +1,37 @@
use shared::models::OAuthAccount;
use sqlx::PgPool;
use uuid::Uuid;
pub async fn find_by_provider(
pool: &PgPool,
provider: &str,
provider_account_id: &str,
) -> Result<Option<OAuthAccount>, sqlx::Error> {
sqlx::query_as::<_, OAuthAccount>(
r#"SELECT id, user_id, provider, provider_account_id, created_at
FROM oauth_accounts WHERE provider = $1 AND provider_account_id = $2"#,
)
.bind(provider)
.bind(provider_account_id)
.fetch_optional(pool)
.await
}
pub async fn create(
pool: &PgPool,
user_id: Uuid,
provider: &str,
provider_account_id: &str,
) -> Result<OAuthAccount, sqlx::Error> {
sqlx::query_as::<_, OAuthAccount>(
r#"INSERT INTO oauth_accounts (id, user_id, provider, provider_account_id)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, provider, provider_account_id, created_at"#,
)
.bind(Uuid::new_v4())
.bind(user_id)
.bind(provider)
.bind(provider_account_id)
.fetch_one(pool)
.await
}

View File

@@ -0,0 +1,76 @@
use shared::models::Subscription;
use sqlx::PgPool;
use uuid::Uuid;
pub async fn find_by_user(pool: &PgPool, user_id: Uuid) -> Result<Option<Subscription>, sqlx::Error> {
sqlx::query_as::<_, Subscription>(
r#"SELECT id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
status, tier, current_period_start, current_period_end, created_at, updated_at
FROM subscriptions WHERE user_id = $1 ORDER BY created_at DESC LIMIT 1"#,
)
.bind(user_id)
.fetch_optional(pool)
.await
}
pub async fn find_by_stripe_subscription(
pool: &PgPool,
stripe_subscription_id: &str,
) -> Result<Option<Subscription>, sqlx::Error> {
sqlx::query_as::<_, Subscription>(
r#"SELECT id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
status, tier, current_period_start, current_period_end, created_at, updated_at
FROM subscriptions WHERE stripe_subscription_id = $1"#,
)
.bind(stripe_subscription_id)
.fetch_optional(pool)
.await
}
pub async fn create_or_update(
pool: &PgPool,
user_id: Uuid,
stripe_customer_id: Option<&str>,
stripe_subscription_id: Option<&str>,
stripe_price_id: Option<&str>,
status: &str,
tier: &str,
) -> Result<Subscription, sqlx::Error> {
sqlx::query_as::<_, Subscription>(
r#"INSERT INTO subscriptions (id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id, status, tier)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_id) DO UPDATE SET
stripe_customer_id = EXCLUDED.stripe_customer_id,
stripe_subscription_id = EXCLUDED.stripe_subscription_id,
stripe_price_id = EXCLUDED.stripe_price_id,
status = EXCLUDED.status,
tier = EXCLUDED.tier,
updated_at = NOW()
RETURNING id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
status, tier, current_period_start, current_period_end, created_at, updated_at"#,
)
.bind(Uuid::new_v4())
.bind(user_id)
.bind(stripe_customer_id)
.bind(stripe_subscription_id)
.bind(stripe_price_id)
.bind(status)
.bind(tier)
.fetch_one(pool)
.await
}
pub async fn update_status(
pool: &PgPool,
stripe_subscription_id: &str,
status: &str,
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"UPDATE subscriptions SET status = $1, updated_at = NOW() WHERE stripe_subscription_id = $2"#,
)
.bind(status)
.bind(stripe_subscription_id)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,68 @@
use shared::models::{Team, TeamMember};
use sqlx::PgPool;
use uuid::Uuid;
pub async fn create(pool: &PgPool, name: &str, slug: &str, owner_id: Uuid) -> Result<Team, sqlx::Error> {
sqlx::query_as::<_, Team>(
r#"INSERT INTO teams (id, name, slug, owner_id)
VALUES ($1, $2, $3, $4)
RETURNING id, name, slug, owner_id, created_at, updated_at"#,
)
.bind(Uuid::new_v4())
.bind(name)
.bind(slug)
.bind(owner_id)
.fetch_one(pool)
.await
}
pub async fn find_by_slug(pool: &PgPool, slug: &str) -> Result<Option<Team>, sqlx::Error> {
sqlx::query_as::<_, Team>(
r#"SELECT id, name, slug, owner_id, created_at, updated_at FROM teams WHERE slug = $1"#,
)
.bind(slug)
.fetch_optional(pool)
.await
}
pub async fn find_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Team>, sqlx::Error> {
sqlx::query_as::<_, Team>(
r#"SELECT id, name, slug, owner_id, created_at, updated_at FROM teams WHERE id = $1"#,
)
.bind(id)
.fetch_optional(pool)
.await
}
pub async fn add_member(pool: &PgPool, team_id: Uuid, user_id: Uuid, role: &str) -> Result<TeamMember, sqlx::Error> {
sqlx::query_as::<_, TeamMember>(
r#"INSERT INTO team_members (id, team_id, user_id, role)
VALUES ($1, $2, $3, $4)
RETURNING id, team_id, user_id, role, created_at"#,
)
.bind(Uuid::new_v4())
.bind(team_id)
.bind(user_id)
.bind(role)
.fetch_one(pool)
.await
}
pub async fn list_members(pool: &PgPool, team_id: Uuid) -> Result<Vec<TeamMember>, sqlx::Error> {
sqlx::query_as::<_, TeamMember>(
r#"SELECT id, team_id, user_id, role, created_at FROM team_members WHERE team_id = $1"#,
)
.bind(team_id)
.fetch_all(pool)
.await
}
pub async fn find_member(pool: &PgPool, team_id: Uuid, user_id: Uuid) -> Result<Option<TeamMember>, sqlx::Error> {
sqlx::query_as::<_, TeamMember>(
r#"SELECT id, team_id, user_id, role, created_at FROM team_members WHERE team_id = $1 AND user_id = $2"#,
)
.bind(team_id)
.bind(user_id)
.fetch_optional(pool)
.await
}

View File

@@ -0,0 +1,47 @@
use shared::models::UsageLog;
use sqlx::PgPool;
use uuid::Uuid;
pub async fn create(
pool: &PgPool,
user_id: Uuid,
api_key_id: Uuid,
endpoint: &str,
url: &str,
status: &str,
credits_used: i64,
duration_ms: i64,
) -> Result<UsageLog, sqlx::Error> {
sqlx::query_as::<_, UsageLog>(
r#"INSERT INTO usage_logs (id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms, created_at"#,
)
.bind(Uuid::new_v4())
.bind(user_id)
.bind(api_key_id)
.bind(endpoint)
.bind(url)
.bind(status)
.bind(credits_used)
.bind(duration_ms)
.fetch_one(pool)
.await
}
pub async fn list_by_user(
pool: &PgPool,
user_id: Uuid,
limit: i64,
offset: i64,
) -> Result<Vec<UsageLog>, sqlx::Error> {
sqlx::query_as::<_, UsageLog>(
r#"SELECT id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms, created_at
FROM usage_logs WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
)
.bind(user_id)
.bind(limit)
.bind(offset)
.fetch_all(pool)
.await
}

View File

@@ -0,0 +1,64 @@
use shared::models::User;
use sqlx::PgPool;
use uuid::Uuid;
pub async fn find_by_email(pool: &PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
sqlx::query_as::<_, User>(
r#"SELECT id, email, password_hash, google_id, credits, tier, created_at, updated_at FROM users WHERE email = $1"#,
)
.bind(email)
.fetch_optional(pool)
.await
}
pub async fn find_by_id(pool: &PgPool, id: Uuid) -> Result<Option<User>, sqlx::Error> {
sqlx::query_as::<_, User>(
r#"SELECT id, email, password_hash, google_id, credits, tier, created_at, updated_at FROM users WHERE id = $1"#,
)
.bind(id)
.fetch_optional(pool)
.await
}
pub async fn create(
pool: &PgPool,
email: &str,
password_hash: Option<&str>,
google_id: Option<&str>,
) -> Result<User, sqlx::Error> {
sqlx::query_as::<_, User>(
r#"INSERT INTO users (id, email, password_hash, google_id, credits, tier)
VALUES ($1, $2, $3, $4, 30, 'free')
RETURNING id, email, password_hash, google_id, credits, tier, created_at, updated_at"#,
)
.bind(Uuid::new_v4())
.bind(email)
.bind(password_hash)
.bind(google_id)
.fetch_one(pool)
.await
}
pub async fn deduct_credits(pool: &PgPool, user_id: Uuid, amount: i64) -> Result<bool, sqlx::Error> {
let result = sqlx::query(
r#"UPDATE users SET credits = credits - $1 WHERE id = $2 AND credits >= $1"#,
)
.bind(amount)
.bind(user_id)
.execute(pool)
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn add_credits(pool: &PgPool, user_id: Uuid, amount: i64) -> Result<(), sqlx::Error> {
sqlx::query(
r#"UPDATE users SET credits = credits + $1 WHERE id = $2"#,
)
.bind(amount)
.bind(user_id)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,18 @@
use shared::models::User;
#[test]
fn test_user_model_serialization() {
let user = User {
id: uuid::Uuid::new_v4(),
email: "test@example.com".to_string(),
password_hash: Some("hash".to_string()),
google_id: None,
credits: 30,
tier: "free".to_string(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
};
let json = serde_json::to_string(&user).unwrap();
assert!(json.contains("test@example.com"));
}

15
crates/shared/Cargo.toml Normal file
View File

@@ -0,0 +1,15 @@
[package]
name = "shared"
version = "0.1.0"
edition = "2021"
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
url = { workspace = true }
regex = { workspace = true }
config = { workspace = true }
sqlx = { workspace = true }

26
crates/shared/src/api.rs Normal file
View File

@@ -0,0 +1,26 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiResponse<T> {
pub success: bool,
pub data: Option<T>,
pub error: Option<String>,
}
impl<T> ApiResponse<T> {
pub fn ok(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
}
}
pub fn err(msg: impl Into<String>) -> Self {
Self {
success: false,
data: None,
error: Some(msg.into()),
}
}
}

View File

@@ -0,0 +1,25 @@
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
pub database_url: String,
pub redis_url: String,
pub jwt_secret: String,
pub s3_endpoint: String,
pub s3_bucket: String,
pub s3_region: String,
pub s3_access_key: String,
pub s3_secret_key: String,
pub app_port: u16,
pub app_host: String,
pub playwright_script_path: String,
}
impl AppConfig {
pub fn from_env() -> Result<Self, config::ConfigError> {
config::Config::builder()
.add_source(config::Environment::default())
.build()?
.try_deserialize()
}
}

View File

@@ -0,0 +1,41 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AppError {
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Redis error: {0}")]
Redis(String),
#[error("S3 error: {0}")]
S3(String),
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Browser automation failed: {0}")]
BrowserError(String),
#[error("Rate limit exceeded")]
RateLimit,
#[error("Insufficient credits")]
InsufficientCredits,
#[error("Unauthorized")]
Unauthorized,
#[error("Not found")]
NotFound,
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Internal error: {0}")]
Internal(String),
}
impl AppError {
pub fn status_code(&self) -> u16 {
match self {
AppError::InvalidUrl(_) | AppError::BadRequest(_) => 400,
AppError::Unauthorized => 401,
AppError::InsufficientCredits => 403,
AppError::NotFound => 404,
AppError::RateLimit => 429,
AppError::BrowserError(_) => 500,
_ => 500,
}
}
}

24
crates/shared/src/jobs.rs Normal file
View File

@@ -0,0 +1,24 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::models::CrawlOptions;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlJob {
pub job_id: Uuid,
pub user_id: Uuid,
pub api_key_id: Uuid,
pub endpoint: String,
pub url: String,
pub options: CrawlOptions,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlResult {
pub job_id: Uuid,
pub success: bool,
pub data: Option<serde_json::Value>,
pub error: Option<String>,
pub duration_ms: i64,
pub file_url: Option<String>,
}

6
crates/shared/src/lib.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod api;
pub mod config;
pub mod error;
pub mod jobs;
pub mod models;
pub mod queue;

136
crates/shared/src/models.rs Normal file
View File

@@ -0,0 +1,136 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct User {
pub id: Uuid,
pub email: String,
pub password_hash: Option<String>,
pub google_id: Option<String>,
pub credits: i64,
pub tier: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct ApiKey {
pub id: Uuid,
pub user_id: Uuid,
pub key_hash: String,
pub name: String,
pub last_used_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct OAuthAccount {
pub id: Uuid,
pub user_id: Uuid,
pub provider: String,
pub provider_account_id: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Subscription {
pub id: Uuid,
pub user_id: Uuid,
pub stripe_customer_id: Option<String>,
pub stripe_subscription_id: Option<String>,
pub stripe_price_id: Option<String>,
pub status: String,
pub tier: String,
pub current_period_start: Option<DateTime<Utc>>,
pub current_period_end: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct UsageLog {
pub id: Uuid,
pub user_id: Uuid,
pub api_key_id: Uuid,
pub endpoint: String,
pub url: String,
pub status: String,
pub credits_used: i64,
pub duration_ms: i64,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Team {
pub id: Uuid,
pub name: String,
pub slug: String,
pub owner_id: Uuid,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct TeamMember {
pub id: Uuid,
pub team_id: Uuid,
pub user_id: Uuid,
pub role: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlRequest {
pub url: String,
#[serde(default)]
pub options: CrawlOptions,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CrawlOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub full_page: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub width: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub height: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub wait_for: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user_agent: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub selectors: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_html: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub webhook_url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<std::collections::HashMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mobile: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scroll_to_bottom: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stealth: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_proxy: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub solve_captcha: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrawlResponse {
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub calls_remaining: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}

View File

@@ -0,0 +1,27 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::models::CrawlOptions;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Job {
pub id: Uuid,
pub user_id: Uuid,
pub api_key_id: Uuid,
pub endpoint: String,
pub url: String,
pub options: CrawlOptions,
pub webhook_url: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobResult {
pub id: Uuid,
pub success: bool,
pub data: Option<serde_json::Value>,
pub error: Option<String>,
pub duration_ms: i64,
}
pub const QUEUE_NAME: &str = "crawlapi:jobs";
pub const RESULT_PREFIX: &str = "crawlapi:results:";

26
crates/worker/Cargo.toml Normal file
View File

@@ -0,0 +1,26 @@
[package]
name = "worker"
version = "0.1.0"
edition = "2021"
[dependencies]
shared = { path = "../shared" }
db = { path = "../db" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
redis = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["json", "env-filter"] }
chrono = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
aws-config = { workspace = true }
aws-sdk-s3 = { workspace = true }
config = { workspace = true }
tokio-util = { workspace = true }
futures = { workspace = true }
uuid = { workspace = true }
reqwest = { workspace = true }
sentry = "0.36"
sqlx = { workspace = true }

230
crates/worker/src/main.rs Normal file
View File

@@ -0,0 +1,230 @@
use chrono::Utc;
use db::connection::create_pool;
use redis::AsyncCommands;
use shared::{
config::AppConfig,
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
};
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::time::sleep;
use tracing::{info_span, Instrument};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let sentry_dsn = std::env::var("SENTRY_DSN").ok();
let _guard = sentry_dsn.map(|dsn| {
sentry::init((dsn, sentry::ClientOptions {
release: sentry::release_name!(),
..Default::default()
}))
});
let json_logging = std::env::var("JSON_LOGGING").unwrap_or_else(|_| "false".to_string()) == "true";
if json_logging {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "worker=debug".into()))
.with(
tracing_subscriber::fmt::layer()
.json()
.with_current_span(true)
.with_span_list(true)
.with_target(true),
)
.init();
} else {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "worker=debug".into()))
.with(tracing_subscriber::fmt::layer())
.init();
}
let config = AppConfig::from_env()?;
let db = create_pool(&config.database_url).await?;
let redis_client = redis::Client::open(config.redis_url.clone())?;
let mut redis_conn = redis_client.get_multiplexed_tokio_connection().await?;
tracing::info!("Worker started. Waiting for jobs...");
loop {
let job_json: Option<(String, String)> = redis::cmd("BLPOP")
.arg(QUEUE_NAME)
.arg(5)
.query_async(&mut redis_conn)
.await?;
if let Some((_, json)) = job_json {
let job: Job = match serde_json::from_str(&json) {
Ok(j) => j,
Err(e) => {
tracing::error!("Failed to deserialize job: {}", e);
continue;
}
};
let span = info_span!(
"process_job",
job_id = %job.id,
user_id = %job.user_id,
endpoint = %job.endpoint,
url = %job.url,
);
process_single_job(&config, &db, &mut redis_conn, &job)
.instrument(span)
.await;
}
}
}
async fn process_single_job(
config: &AppConfig,
db: &sqlx::PgPool,
redis_conn: &mut redis::aio::MultiplexedConnection,
job: &Job,
) {
tracing::info!("Processing job {}: {} {}", job.id, job.endpoint, job.url);
let start = Instant::now();
let result = process_job_with_retry(config, job).await;
let duration = start.elapsed().as_millis() as i64;
let job_result = match result {
Ok(data) => JobResult {
id: job.id,
success: true,
data: Some(data),
error: None,
duration_ms: duration,
},
Err(e) => {
tracing::error!("Job {} failed after retries: {}", job.id, e);
JobResult {
id: job.id,
success: false,
data: None,
error: Some(e.clone()),
duration_ms: duration,
}
}
};
let result_json = serde_json::to_string(&job_result).unwrap();
let result_key = format!("{}{}", RESULT_PREFIX, job.id);
let _: () = redis_conn.set_ex(&result_key, result_json, 300).await.unwrap_or(());
let status = if job_result.success { "success" } else { "error" };
let _ = db::repos::usage_logs::create(
db,
job.user_id,
job.api_key_id,
&job.endpoint,
&job.url,
status,
1,
duration,
)
.await;
if let Some(webhook_url) = &job.webhook_url {
let _ = send_webhook(webhook_url, &job_result).await;
}
if !job_result.success {
let dlq_key = format!("crawlapi:dlq:{}", job.id);
let dlq_data = serde_json::json!({
"job": job,
"error": job_result.error,
"failed_at": Utc::now().to_rfc3339(),
});
let _: () = redis_conn.set_ex(dlq_key, dlq_data.to_string(), 86400).await.unwrap_or(());
tracing::warn!("Job {} moved to DLQ", job.id);
}
tracing::info!("Job {} completed in {}ms", job.id, duration);
}
async fn process_job_with_retry(config: &AppConfig, job: &Job) -> Result<serde_json::Value, String> {
let max_retries = 3;
let mut last_error = String::new();
for attempt in 0..max_retries {
if attempt > 0 {
let backoff = Duration::from_secs(2_u64.pow(attempt as u32));
tracing::info!(
"Retrying job {} (attempt {}/{}), waiting {:?}",
job.id,
attempt + 1,
max_retries,
backoff
);
sleep(backoff).await;
}
match process_job(config, job).await {
Ok(data) => return Ok(data),
Err(e) => {
last_error = e;
tracing::warn!(
"Job {} attempt {}/{} failed: {}",
job.id,
attempt + 1,
max_retries,
last_error
);
}
}
}
Err(format!("Failed after {} retries: {}", max_retries, last_error))
}
async fn process_job(config: &AppConfig, job: &Job) -> Result<serde_json::Value, String> {
let script_path = &config.playwright_script_path;
let mut cmd = Command::new("node");
cmd.arg(script_path)
.arg(&job.endpoint)
.arg(serde_json::to_string(&job.url).unwrap())
.arg(serde_json::to_string(&job.options).unwrap())
.env("OUTPUT_DIR", "/tmp/crawlapi")
.env("BROWSER_POOL_SIZE", std::env::var("BROWSER_POOL_SIZE").unwrap_or_else(|_| "5".to_string()))
.env("MAX_PAGES_PER_BROWSER", std::env::var("MAX_PAGES_PER_BROWSER").unwrap_or_else(|_| "10".to_string()));
if let Ok(proxy_url) = std::env::var("PROXY_URL") {
cmd.env("PROXY_URL", proxy_url);
}
if let Ok(captcha_key) = std::env::var("CAPTCHA_API_KEY") {
cmd.env("CAPTCHA_API_KEY", captcha_key);
}
let output = cmd.output()
.await
.map_err(|e| format!("Failed to execute browser: {}", e))?;
if !output.status.success() {
return Err(format!("Browser error: {}", String::from_utf8_lossy(&output.stderr)));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let result: serde_json::Value = serde_json::from_str(&stdout)
.map_err(|e| format!("Invalid JSON from browser: {} | output: {}", e, stdout))?;
Ok(result)
}
async fn send_webhook(url: &str, result: &JobResult) -> Result<(), reqwest::Error> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()?;
let _ = client
.post(url)
.json(result)
.send()
.await?;
Ok(())
}