Initial commit: Full Crawl API implementation
This commit is contained in:
39
crates/api/Cargo.toml
Normal file
39
crates/api/Cargo.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
shared = { path = "../shared" }
|
||||
db = { path = "../db" }
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
tokio = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
aws-config = { workspace = true }
|
||||
aws-sdk-s3 = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
bcrypt = { workspace = true }
|
||||
config = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
url = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
scraper = { workspace = true }
|
||||
markdown = { workspace = true }
|
||||
md5 = "0.7"
|
||||
prometheus = "0.13"
|
||||
lazy_static = "1.5"
|
||||
sentry = "0.36"
|
||||
async-stripe = { version = "1.0.0-rc.5", features = ["default-tls"] }
|
||||
aws-sdk-secretsmanager = "1.0"
|
||||
51
crates/api/src/bin/seed.rs
Normal file
51
crates/api/src/bin/seed.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use db::connection::create_pool;
|
||||
use db::repos::{api_keys, users};
|
||||
use shared::config::AppConfig;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let config = Arc::new(AppConfig::from_env()?);
|
||||
let db = create_pool(&config.database_url).await?;
|
||||
|
||||
// Create test user
|
||||
let email = "demo@crawlapi.dev";
|
||||
let password = "demo123456";
|
||||
let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
|
||||
|
||||
let user = match users::find_by_email(&db, email).await? {
|
||||
Some(u) => {
|
||||
tracing::info!("User already exists: {}", u.id);
|
||||
u
|
||||
}
|
||||
None => {
|
||||
let u = users::create(&db, email, Some(&password_hash), None).await?;
|
||||
tracing::info!("Created user: {} with 30 free credits", u.id);
|
||||
u
|
||||
}
|
||||
};
|
||||
|
||||
// Create API key
|
||||
let api_key = format!("crawlapi_demo_{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let key_hash = format!("{:x}", md5::compute(&api_key));
|
||||
|
||||
let key = api_keys::create(&db, user.id, &key_hash, "Demo Key").await?;
|
||||
tracing::info!("Created API key: {} (id: {})", api_key, key.id);
|
||||
|
||||
println!("\n========================================");
|
||||
println!("SEED DATA CREATED");
|
||||
println!("========================================");
|
||||
println!("Email: {}", email);
|
||||
println!("Password: {}", password);
|
||||
println!("API Key: {}", api_key);
|
||||
println!("Credits: {}", user.credits);
|
||||
println!("========================================\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
crates/api/src/lib.rs
Normal file
8
crates/api/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod middleware;
|
||||
pub mod metrics;
|
||||
pub mod queue;
|
||||
pub mod routes;
|
||||
pub mod secrets;
|
||||
pub mod state;
|
||||
pub mod storage;
|
||||
pub mod validation;
|
||||
79
crates/api/src/main.rs
Normal file
79
crates/api/src/main.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use std::sync::Arc;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
use api::{metrics, routes, state::AppState, storage};
|
||||
use db::connection::create_pool;
|
||||
use shared::config::AppConfig;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let sentry_dsn = std::env::var("SENTRY_DSN").ok();
|
||||
let _guard = sentry_dsn.map(|dsn| {
|
||||
sentry::init((dsn, sentry::ClientOptions {
|
||||
release: sentry::release_name!(),
|
||||
..Default::default()
|
||||
}))
|
||||
});
|
||||
|
||||
// Structured JSON logging with correlation IDs
|
||||
let json_logging = std::env::var("JSON_LOGGING").unwrap_or_else(|_| "false".to_string()) == "true";
|
||||
|
||||
if json_logging {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
|
||||
)
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.json()
|
||||
.with_current_span(true)
|
||||
.with_span_list(true)
|
||||
.with_target(true),
|
||||
)
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "api=debug,tower_http=debug".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
}
|
||||
|
||||
metrics::register_metrics();
|
||||
|
||||
let config = Arc::new(AppConfig::from_env()?);
|
||||
|
||||
let db = create_pool(&config.database_url).await?;
|
||||
sqlx::migrate!("../db/migrations").run(&db).await?;
|
||||
|
||||
let redis = redis::Client::open(config.redis_url.clone())?;
|
||||
let redis_conn = redis.get_multiplexed_tokio_connection().await?;
|
||||
|
||||
let s3_config = aws_config::from_env()
|
||||
.endpoint_url(&config.s3_endpoint)
|
||||
.load()
|
||||
.await;
|
||||
let s3 = aws_sdk_s3::Client::new(&s3_config);
|
||||
|
||||
storage::ensure_bucket_exists(&s3, &config.s3_bucket).await?;
|
||||
|
||||
let state = AppState {
|
||||
config,
|
||||
db,
|
||||
redis: redis_conn,
|
||||
s3,
|
||||
};
|
||||
|
||||
let app = routes::create_router(state)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
|
||||
tracing::info!("API server listening on {}", listener.local_addr()?);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
31
crates/api/src/metrics.rs
Normal file
31
crates/api/src/metrics.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use lazy_static::lazy_static;
|
||||
use prometheus::{CounterVec, HistogramVec, Registry};
|
||||
use std::time::Instant;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref REGISTRY: Registry = Registry::new();
|
||||
pub static ref REQUEST_COUNTER: CounterVec = CounterVec::new(
|
||||
prometheus::Opts::new("api_requests_total", "Total API requests"),
|
||||
&["endpoint", "status"]
|
||||
)
|
||||
.unwrap();
|
||||
pub static ref REQUEST_DURATION: HistogramVec = HistogramVec::new(
|
||||
prometheus::HistogramOpts::new("api_request_duration_seconds", "Request duration in seconds"),
|
||||
&["endpoint"]
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn register_metrics() {
|
||||
REGISTRY.register(Box::new(REQUEST_COUNTER.clone())).unwrap();
|
||||
REGISTRY.register(Box::new(REQUEST_DURATION.clone())).unwrap();
|
||||
}
|
||||
|
||||
pub fn record_request(endpoint: &str, status: &str) {
|
||||
REQUEST_COUNTER.with_label_values(&[endpoint, status]).inc();
|
||||
}
|
||||
|
||||
pub fn record_duration(endpoint: &str, start: Instant) {
|
||||
let duration = start.elapsed().as_secs_f64();
|
||||
REQUEST_DURATION.with_label_values(&[endpoint]).observe(duration);
|
||||
}
|
||||
52
crates/api/src/middleware/auth.rs
Normal file
52
crates/api/src/middleware/auth.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use db::repos::api_keys;
|
||||
use shared::models::User;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiKeyAuth {
|
||||
pub user: User,
|
||||
pub api_key_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
pub async fn api_key_middleware(
|
||||
State(state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let api_key = req
|
||||
.headers()
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let key_hash = format!("{:x}", md5::compute(api_key));
|
||||
|
||||
let api_key_record = api_keys::find_by_key_hash(&state.db, &key_hash)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user = db::repos::users::find_by_id(&state.db, api_key_record.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
api_keys::update_last_used(&state.db, api_key_record.id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let auth = ApiKeyAuth {
|
||||
user,
|
||||
api_key_id: api_key_record.id,
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(auth);
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
42
crates/api/src/middleware/correlation.rs
Normal file
42
crates/api/src/middleware/correlation.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header::HeaderValue, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::{info_span, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn correlation_id_middleware(
|
||||
State(_state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let correlation_id = req
|
||||
.headers()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
|
||||
req.headers_mut().insert(
|
||||
"x-correlation-id",
|
||||
HeaderValue::from_str(&correlation_id).unwrap(),
|
||||
);
|
||||
|
||||
let method = req.method().to_string();
|
||||
let uri = req.uri().to_string();
|
||||
|
||||
let span = info_span!(
|
||||
"http_request",
|
||||
correlation_id = %correlation_id,
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
);
|
||||
|
||||
let response = next.run(req).instrument(span).await;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
49
crates/api/src/middleware/jwt.rs
Normal file
49
crates/api/src/middleware/jwt.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct JwtClaims {
|
||||
pub sub: String,
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JwtAuth {
|
||||
pub user_id: Uuid,
|
||||
}
|
||||
|
||||
pub async fn jwt_middleware(
|
||||
State(state): State<AppState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let token = req
|
||||
.headers()
|
||||
.get("x-auth-token")
|
||||
.or_else(|| req.headers().get("authorization"))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer ").or(Some(v)))
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let validation = Validation::default();
|
||||
let token_data = decode::<JwtClaims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(state.config.jwt_secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user_id = Uuid::parse_str(&token_data.claims.sub).map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
req.extensions_mut().insert(JwtAuth { user_id });
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
5
crates/api/src/middleware/mod.rs
Normal file
5
crates/api/src/middleware/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod auth;
|
||||
pub mod correlation;
|
||||
pub mod jwt;
|
||||
pub mod rate_limit;
|
||||
pub mod waf;
|
||||
36
crates/api/src/middleware/rate_limit.rs
Normal file
36
crates/api/src/middleware/rate_limit.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use redis::AsyncCommands;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let api_key = req
|
||||
.headers()
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("anonymous");
|
||||
|
||||
let key = format!("rate_limit:{}", api_key);
|
||||
let mut conn = state.redis.clone();
|
||||
|
||||
let count: i64 = conn.incr(&key, 1).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if count == 1 {
|
||||
let _: () = conn.expire(&key, 60).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
if count > 60 {
|
||||
return Err(StatusCode::TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
51
crates/api/src/middleware/waf.rs
Normal file
51
crates/api/src/middleware/waf.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use axum::{
|
||||
extract::{ConnectInfo, Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use redis::AsyncCommands;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn ip_rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let ip = addr.ip().to_string();
|
||||
let key = format!("ip_rate_limit:{}", ip);
|
||||
let mut conn = state.redis.clone();
|
||||
|
||||
// Check if IP is blocked
|
||||
let blocked: bool = conn.exists(format!("ip_blocked:{}", ip))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if blocked {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
// Aggressive rate limiting: 100 req/min per IP
|
||||
let count: i64 = conn.incr(&key, 1)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if count == 1 {
|
||||
let _: () = conn.expire(&key, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
if count > 100 {
|
||||
// Block IP for 1 hour after exceeding limit
|
||||
let _: () = conn.set_ex(format!("ip_blocked:{}", ip), "1", 3600)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
return Err(StatusCode::TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
94
crates/api/src/queue.rs
Normal file
94
crates/api/src/queue.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use redis::AsyncCommands;
|
||||
use shared::{
|
||||
models::CrawlOptions,
|
||||
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
pub async fn enqueue_job(
|
||||
state: &AppState,
|
||||
user_id: Uuid,
|
||||
api_key_id: Uuid,
|
||||
endpoint: &str,
|
||||
url: &str,
|
||||
options: &CrawlOptions,
|
||||
webhook_url: Option<String>,
|
||||
) -> Result<Uuid, redis::RedisError> {
|
||||
let job = Job {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
api_key_id,
|
||||
endpoint: endpoint.to_string(),
|
||||
url: url.to_string(),
|
||||
options: options.clone(),
|
||||
webhook_url,
|
||||
};
|
||||
|
||||
let job_json = serde_json::to_string(&job).unwrap();
|
||||
let mut conn = state.redis.clone();
|
||||
conn.rpush::<_, _, ()>(QUEUE_NAME, job_json).await?;
|
||||
|
||||
Ok(job.id)
|
||||
}
|
||||
|
||||
pub async fn wait_for_result(
|
||||
state: &AppState,
|
||||
job_id: Uuid,
|
||||
timeout_secs: u64,
|
||||
) -> Result<Option<JobResult>, redis::RedisError> {
|
||||
let result_key = format!("{}{}", RESULT_PREFIX, job_id);
|
||||
let mut conn = state.redis.clone();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
while start.elapsed().as_secs() < timeout_secs {
|
||||
let result_json: Option<String> = conn.get(&result_key).await?;
|
||||
if let Some(json) = result_json {
|
||||
let result: JobResult = serde_json::from_str(&json).unwrap_or_else(|_| JobResult {
|
||||
id: job_id,
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Failed to deserialize result".to_string()),
|
||||
duration_ms: 0,
|
||||
});
|
||||
return Ok(Some(result));
|
||||
}
|
||||
sleep(Duration::from_millis(200)).await;
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn get_cache_key(url: &str, endpoint: &str, options: &CrawlOptions) -> String {
|
||||
let opts_json = serde_json::to_string(options).unwrap_or_default();
|
||||
let hash = format!("{:x}", md5::compute(format!("{}:{}:{}", url, endpoint, opts_json)));
|
||||
format!("crawlapi:cache:{}", hash)
|
||||
}
|
||||
|
||||
pub async fn get_cached_result(
|
||||
state: &AppState,
|
||||
cache_key: &str,
|
||||
) -> Result<Option<JobResult>, redis::RedisError> {
|
||||
let mut conn = state.redis.clone();
|
||||
let result_json: Option<String> = conn.get::<_, Option<String>>(cache_key).await?;
|
||||
if let Some(json) = result_json {
|
||||
let result: JobResult = serde_json::from_str(&json).unwrap();
|
||||
return Ok(Some(result));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn set_cached_result(
|
||||
state: &AppState,
|
||||
cache_key: &str,
|
||||
result: &JobResult,
|
||||
ttl_secs: u64,
|
||||
) -> Result<(), redis::RedisError> {
|
||||
let mut conn = state.redis.clone();
|
||||
let json = serde_json::to_string(result).unwrap();
|
||||
conn.set_ex::<_, _, ()>(cache_key, json, ttl_secs).await?;
|
||||
Ok(())
|
||||
}
|
||||
130
crates/api/src/routes/ai.rs
Normal file
130
crates/api/src/routes/ai.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::CrawlRequest;
|
||||
|
||||
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AiExtractRequest {
|
||||
pub url: String,
|
||||
pub schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AiExtractResponse {
|
||||
pub success: bool,
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn extract(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<ApiKeyAuth>,
|
||||
Json(body): Json<AiExtractRequest>,
|
||||
) -> Result<Json<AiExtractResponse>, StatusCode> {
|
||||
let openai_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
|
||||
if openai_key.is_empty() {
|
||||
return Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("OpenAI not configured".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Crawl the page via queue
|
||||
let crawl_req = CrawlRequest {
|
||||
url: body.url.clone(),
|
||||
options: Default::default(),
|
||||
};
|
||||
|
||||
let job_id = queue::enqueue_job(
|
||||
&state,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
"crawl",
|
||||
&body.url,
|
||||
&crawl_req.options,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let job_result = queue::wait_for_result(&state, job_id, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
match job_result {
|
||||
Some(result) if result.success => {
|
||||
let data = result.data.unwrap_or_default();
|
||||
let html = data.get("html").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let title = data.get("title").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
// Call OpenAI
|
||||
let client = reqwest::Client::new();
|
||||
let system_prompt = format!(
|
||||
"You are a web scraping assistant. Extract structured data from the following HTML page titled '{}'. \
|
||||
Return ONLY a JSON object matching the requested schema. Do not include any explanation.",
|
||||
title
|
||||
);
|
||||
|
||||
let user_prompt = if let Some(p) = body.prompt {
|
||||
p
|
||||
} else {
|
||||
format!("Extract data from this HTML according to schema: {}\n\nHTML:\n{}",
|
||||
body.schema.to_string(),
|
||||
&html[..html.len().min(8000)])
|
||||
};
|
||||
|
||||
let res = client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {}", openai_key))
|
||||
.json(&serde_json::json!({
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{ "role": "system", "content": system_prompt },
|
||||
{ "role": "user", "content": user_prompt }
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"response_format": { "type": "json_object" }
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(response) => {
|
||||
let ai_data: serde_json::Value = response.json().await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
if let Some(content) = ai_data["choices"][0]["message"]["content"].as_str() {
|
||||
let parsed: serde_json::Value = serde_json::from_str(content).unwrap_or_else(|_| serde_json::json!({"raw": content}));
|
||||
Ok(Json(AiExtractResponse {
|
||||
success: true,
|
||||
data: Some(parsed),
|
||||
error: None,
|
||||
}))
|
||||
} else {
|
||||
Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Invalid OpenAI response".to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(format!("OpenAI error: {}", e)),
|
||||
})),
|
||||
}
|
||||
}
|
||||
_ => Ok(Json(AiExtractResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some("Failed to crawl page".to_string()),
|
||||
})),
|
||||
}
|
||||
}
|
||||
142
crates/api/src/routes/auth.rs
Normal file
142
crates/api/src/routes/auth.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use axum::{
|
||||
extract::{Json, Path, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::{api_keys, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::{ApiKey, User};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuthResponse {
|
||||
pub user: User,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<RegisterRequest>,
|
||||
) -> Result<Json<AuthResponse>, StatusCode> {
|
||||
if users::find_by_email(&state.db, &body.email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?.is_some() {
|
||||
return Err(StatusCode::CONFLICT);
|
||||
}
|
||||
|
||||
let password_hash = bcrypt::hash(&body.password, bcrypt::DEFAULT_COST).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let user = users::create(&state.db, &body.email, Some(&password_hash), None)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
|
||||
Ok(Json(AuthResponse { user, token }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<LoginRequest>,
|
||||
) -> Result<Json<AuthResponse>, StatusCode> {
|
||||
let user = users::find_by_email(&state.db, &body.email)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let password_hash = user.password_hash.as_ref().ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
if !bcrypt::verify(&body.password, password_hash).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token = create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
|
||||
Ok(Json(AuthResponse { user, token }))
|
||||
}
|
||||
|
||||
pub fn create_jwt(user_id: &str, secret: &str) -> Result<String, StatusCode> {
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
exp: usize,
|
||||
}
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
exp: (chrono::Utc::now() + chrono::Duration::days(30)).timestamp() as usize,
|
||||
};
|
||||
|
||||
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_bytes()))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateApiKeyRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ApiKeyResponse {
|
||||
pub id: Uuid,
|
||||
pub key: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
pub async fn create_api_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateApiKeyRequest>,
|
||||
) -> Result<Json<ApiKeyResponse>, StatusCode> {
|
||||
let api_key = format!("crawlapi_{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
let key_hash = format!("{:x}", md5::compute(&api_key));
|
||||
|
||||
let key = api_keys::create(&state.db, auth.user_id, &key_hash, &body.name)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(ApiKeyResponse {
|
||||
id: key.id,
|
||||
key: api_key,
|
||||
name: key.name,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn list_api_keys(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
) -> Result<Json<Vec<ApiKey>>, StatusCode> {
|
||||
let keys = api_keys::list_by_user(&state.db, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(keys))
|
||||
}
|
||||
|
||||
pub async fn delete_api_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let deleted = api_keys::delete_by_id(&state.db, id, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if deleted {
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
} else {
|
||||
Err(StatusCode::NOT_FOUND)
|
||||
}
|
||||
}
|
||||
296
crates/api/src/routes/crawl.rs
Normal file
296
crates/api/src/routes/crawl.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::{usage_logs, users};
|
||||
use serde_json::json;
|
||||
use shared::{
|
||||
error::AppError,
|
||||
models::{CrawlRequest, CrawlResponse},
|
||||
};
|
||||
use std::time::Instant;
|
||||
use tokio::fs;
|
||||
|
||||
use crate::{middleware::auth::ApiKeyAuth, queue, state::AppState, storage, validation};
|
||||
|
||||
async fn upload_files_if_needed(
|
||||
state: &AppState,
|
||||
endpoint: &str,
|
||||
result: &mut serde_json::Value,
|
||||
) -> Result<(), AppError> {
|
||||
// Handle file_path
|
||||
let file_path_opt = result.get("file_path").and_then(|v| v.as_str()).map(String::from);
|
||||
if let Some(file_path) = file_path_opt {
|
||||
let file_data = fs::read(&file_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read file: {}", e)))?;
|
||||
|
||||
let ext = if endpoint == "pdf" { "pdf" } else { "png" };
|
||||
let content_type = if endpoint == "pdf" { "application/pdf" } else { "image/png" };
|
||||
let key = storage::generate_file_key(endpoint, ext);
|
||||
|
||||
storage::upload_file(
|
||||
&state.s3,
|
||||
&state.config.s3_bucket,
|
||||
&key,
|
||||
content_type,
|
||||
file_data,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let public_url = storage::get_public_url(
|
||||
endpoint,
|
||||
&state.config.s3_endpoint,
|
||||
&state.config.s3_bucket,
|
||||
&key,
|
||||
);
|
||||
|
||||
if let Some(obj) = result.as_object_mut() {
|
||||
obj.remove("file_path");
|
||||
obj.insert("url".to_string(), json!(public_url));
|
||||
}
|
||||
|
||||
let _ = fs::remove_file(file_path).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_endpoint(
|
||||
state: State<AppState>,
|
||||
Extension(auth): Extension<ApiKeyAuth>,
|
||||
Json(body): Json<CrawlRequest>,
|
||||
endpoint: &'static str,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate URL
|
||||
if let Err(e) = validation::validate_url(&body.url) {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"error",
|
||||
0,
|
||||
start.elapsed().as_millis() as i64,
|
||||
)
|
||||
.await;
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(auth.user.credits),
|
||||
error: Some(e.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Validate webhook URL if provided
|
||||
if let Some(ref webhook_url) = body.options.webhook_url {
|
||||
if let Err(e) = validation::validate_webhook_url(webhook_url) {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(auth.user.credits),
|
||||
error: Some(e.to_string()),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Check credits
|
||||
if auth.user.credits <= 0 {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(0),
|
||||
error: Some("Insufficient credits".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Deduct credits
|
||||
let has_credits = users::deduct_credits(&state.db, auth.user.id, 1)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !has_credits {
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(0),
|
||||
error: Some("Insufficient credits".to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Try cache first
|
||||
let cache_key = queue::get_cache_key(&body.url, endpoint, &body.options).await;
|
||||
let cached = queue::get_cached_result(&state, &cache_key)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(cached_result) = cached {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"success",
|
||||
1,
|
||||
start.elapsed().as_millis() as i64,
|
||||
)
|
||||
.await;
|
||||
|
||||
return Ok(Json(CrawlResponse {
|
||||
success: cached_result.success,
|
||||
data: cached_result.data,
|
||||
calls_remaining: Some(auth.user.credits - 1),
|
||||
error: cached_result.error,
|
||||
}));
|
||||
}
|
||||
|
||||
// Enqueue job and wait for result
|
||||
let job_id = queue::enqueue_job(
|
||||
&state,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
&body.options,
|
||||
body.options.webhook_url.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let job_result = queue::wait_for_result(&state, job_id, 60)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let duration = start.elapsed().as_millis() as i64;
|
||||
let remaining = auth.user.credits - 1;
|
||||
|
||||
match job_result {
|
||||
Some(mut result) => {
|
||||
// Upload files if needed
|
||||
if let Some(ref mut data) = result.data {
|
||||
let _ = upload_files_if_needed(&state, endpoint, data).await;
|
||||
}
|
||||
|
||||
// Cache successful results for 5 minutes
|
||||
if result.success {
|
||||
let _ = queue::set_cached_result(&state, &cache_key, &result, 300).await;
|
||||
}
|
||||
|
||||
let status = if result.success { "success" } else { "error" };
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
status,
|
||||
1,
|
||||
duration,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(CrawlResponse {
|
||||
success: result.success,
|
||||
data: result.data,
|
||||
calls_remaining: Some(remaining),
|
||||
error: result.error,
|
||||
}))
|
||||
}
|
||||
None => {
|
||||
let _ = usage_logs::create(
|
||||
&state.db,
|
||||
auth.user.id,
|
||||
auth.api_key_id,
|
||||
endpoint,
|
||||
&body.url,
|
||||
"timeout",
|
||||
1,
|
||||
duration,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Json(CrawlResponse {
|
||||
success: false,
|
||||
data: None,
|
||||
calls_remaining: Some(remaining),
|
||||
error: Some("Job timed out".to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_crawl(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "crawl").await
|
||||
}
|
||||
|
||||
pub async fn handle_content(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "content").await
|
||||
}
|
||||
|
||||
pub async fn handle_screenshot(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "screenshot").await
|
||||
}
|
||||
|
||||
pub async fn handle_pdf(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "pdf").await
|
||||
}
|
||||
|
||||
pub async fn handle_markdown(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "markdown").await
|
||||
}
|
||||
|
||||
pub async fn handle_snapshot(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "snapshot").await
|
||||
}
|
||||
|
||||
pub async fn handle_scrape(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "scrape").await
|
||||
}
|
||||
|
||||
pub async fn handle_json(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "json").await
|
||||
}
|
||||
|
||||
pub async fn handle_links(
|
||||
state: State<AppState>,
|
||||
auth: Extension<ApiKeyAuth>,
|
||||
body: Json<CrawlRequest>,
|
||||
) -> Result<Json<CrawlResponse>, StatusCode> {
|
||||
handle_endpoint(state, auth, body, "links").await
|
||||
}
|
||||
72
crates/api/src/routes/mod.rs
Normal file
72
crates/api/src/routes/mod.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
pub mod auth;
|
||||
pub mod crawl;
|
||||
pub mod oauth;
|
||||
pub mod stripe;
|
||||
pub mod ai;
|
||||
pub mod teams;
|
||||
pub mod ws;
|
||||
|
||||
use axum::{
|
||||
middleware,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
use crate::{
|
||||
middleware::{auth::api_key_middleware, correlation::correlation_id_middleware, jwt::jwt_middleware, rate_limit::rate_limit_middleware, waf::ip_rate_limit_middleware},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
pub fn create_router(state: AppState) -> Router {
|
||||
let api_routes = Router::new()
|
||||
.route("/crawl", post(crawl::handle_crawl))
|
||||
.route("/content", post(crawl::handle_content))
|
||||
.route("/screenshot", post(crawl::handle_screenshot))
|
||||
.route("/pdf", post(crawl::handle_pdf))
|
||||
.route("/markdown", post(crawl::handle_markdown))
|
||||
.route("/snapshot", post(crawl::handle_snapshot))
|
||||
.route("/scrape", post(crawl::handle_scrape))
|
||||
.route("/json", post(crawl::handle_json))
|
||||
.route("/links", post(crawl::handle_links))
|
||||
.route("/extract", post(ai::extract))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), api_key_middleware))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), rate_limit_middleware));
|
||||
|
||||
let auth_routes = Router::new()
|
||||
.route("/auth/register", post(auth::register))
|
||||
.route("/auth/login", post(auth::login))
|
||||
.route("/auth/google", get(oauth::google_auth_url))
|
||||
.route("/auth/google/callback", get(oauth::google_callback));
|
||||
|
||||
let protected_routes = Router::new()
|
||||
.route("/auth/api-keys", post(auth::create_api_key))
|
||||
.route("/auth/api-keys", get(auth::list_api_keys))
|
||||
.route("/auth/api-keys/{id}", axum::routing::delete(auth::delete_api_key))
|
||||
.route("/stripe/checkout", post(stripe::create_checkout))
|
||||
.route("/teams", post(teams::create))
|
||||
.route("/teams/{slug}", get(teams::get))
|
||||
.route("/teams/{slug}/members", post(teams::add_member))
|
||||
.route_layer(middleware::from_fn_with_state(state.clone(), jwt_middleware));
|
||||
|
||||
let stripe_webhook = Router::new()
|
||||
.route("/stripe/webhook", post(stripe::webhook));
|
||||
|
||||
Router::new()
|
||||
.nest("/api", api_routes)
|
||||
.nest("/api", auth_routes)
|
||||
.nest("/api", protected_routes)
|
||||
.nest("/api", stripe_webhook)
|
||||
.route("/metrics", get(|| async {
|
||||
use prometheus::Encoder;
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
let mut buffer = vec![];
|
||||
encoder.encode(&crate::metrics::REGISTRY.gather(), &mut buffer).unwrap();
|
||||
String::from_utf8(buffer).unwrap()
|
||||
}))
|
||||
.route("/ws/logs", get(ws::live_logs))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), ip_rate_limit_middleware))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), correlation_id_middleware))
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state)
|
||||
}
|
||||
137
crates/api/src/routes/oauth.rs
Normal file
137
crates/api/src/routes/oauth.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use db::repos::{oauth, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GoogleCallback {
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GoogleAuthUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleTokenResponse {
|
||||
access_token: String,
|
||||
id_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleUserInfo {
|
||||
sub: String,
|
||||
email: String,
|
||||
name: Option<String>,
|
||||
picture: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn google_auth_url(State(_state): State<AppState>) -> Result<Json<GoogleAuthUrl>, StatusCode> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
|
||||
if client_id.is_empty() {
|
||||
return Err(StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
|
||||
|
||||
let url = format!(
|
||||
"https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=email%20profile&access_type=offline&prompt=consent",
|
||||
client_id, redirect_uri
|
||||
);
|
||||
|
||||
Ok(Json(GoogleAuthUrl { url }))
|
||||
}
|
||||
|
||||
pub async fn google_callback(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<GoogleCallback>,
|
||||
) -> Result<Json<super::auth::AuthResponse>, StatusCode> {
|
||||
let client_id = std::env::var("GOOGLE_CLIENT_ID").unwrap_or_default();
|
||||
let client_secret = std::env::var("GOOGLE_CLIENT_SECRET").unwrap_or_default();
|
||||
let redirect_uri = std::env::var("GOOGLE_REDIRECT_URI")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/api/auth/google/callback".to_string());
|
||||
|
||||
if client_id.is_empty() || client_secret.is_empty() {
|
||||
// MVP fallback: create mock user
|
||||
let email = format!("google_user_{}@example.com", ¶ms.code[..8.min(params.code.len())]);
|
||||
let user = match users::find_by_email(&state.db, &email).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
|
||||
Some(u) => u,
|
||||
None => {
|
||||
let u = users::create(&state.db, &email, None, Some(¶ms.code)).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let _ = oauth::create(&state.db, u.id, "google", ¶ms.code).await;
|
||||
u
|
||||
}
|
||||
};
|
||||
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
return Ok(Json(super::auth::AuthResponse { user, token }));
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
let client = reqwest::Client::new();
|
||||
let token_res = client
|
||||
.post("https://oauth2.googleapis.com/token")
|
||||
.form(&[
|
||||
("code", params.code.as_str()),
|
||||
("client_id", client_id.as_str()),
|
||||
("client_secret", client_secret.as_str()),
|
||||
("redirect_uri", redirect_uri.as_str()),
|
||||
("grant_type", "authorization_code"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.json::<GoogleTokenResponse>()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Get user info
|
||||
let user_info = client
|
||||
.get("https://openidconnect.googleapis.com/v1/userinfo")
|
||||
.header("Authorization", format!("Bearer {}", token_res.access_token))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.json::<GoogleUserInfo>()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Find or create user
|
||||
let user = match oauth::find_by_provider(&state.db, "google", &user_info.sub).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(oauth_account) => {
|
||||
users::find_by_id(&state.db, oauth_account.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
}
|
||||
None => {
|
||||
let user = match users::find_by_email(&state.db, &user_info.email).await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(u) => u,
|
||||
None => {
|
||||
users::create(&state.db, &user_info.email, None, Some(&user_info.sub))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
}
|
||||
};
|
||||
|
||||
oauth::create(&state.db, user.id, "google", &user_info.sub)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
user
|
||||
}
|
||||
};
|
||||
|
||||
let token = super::auth::create_jwt(&user.id.to_string(), &state.config.jwt_secret)?;
|
||||
Ok(Json(super::auth::AuthResponse { user, token }))
|
||||
}
|
||||
146
crates/api/src/routes/stripe.rs
Normal file
146
crates/api/src/routes/stripe.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use db::repos::{subscriptions, users};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
use axum::Extension;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateCheckoutRequest {
|
||||
pub price_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CheckoutResponse {
|
||||
pub checkout_url: String,
|
||||
}
|
||||
|
||||
pub async fn create_checkout(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateCheckoutRequest>,
|
||||
) -> Result<Json<CheckoutResponse>, StatusCode> {
|
||||
let stripe_secret = std::env::var("STRIPE_SECRET_KEY").map_err(|_| StatusCode::NOT_IMPLEMENTED)?;
|
||||
|
||||
let user = users::find_by_id(&state.db, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Create Stripe customer via HTTP API directly (simpler than SDK for MVP)
|
||||
let client = reqwest::Client::new();
|
||||
let customer_res = client
|
||||
.post("https://api.stripe.com/v1/customers")
|
||||
.basic_auth(&stripe_secret, Some(""))
|
||||
.form(&[("email", &user.email)])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let customer_data: serde_json::Value = customer_res.json().await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let customer_id = customer_data["id"].as_str().unwrap_or("");
|
||||
|
||||
let success_url = std::env::var("STRIPE_SUCCESS_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/dashboard?success=true".to_string());
|
||||
let cancel_url = std::env::var("STRIPE_CANCEL_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:3000/dashboard?canceled=true".to_string());
|
||||
|
||||
let session_res = client
|
||||
.post("https://api.stripe.com/v1/checkout/sessions")
|
||||
.basic_auth(&stripe_secret, Some(""))
|
||||
.form(&[
|
||||
("customer", customer_id),
|
||||
("success_url", &success_url),
|
||||
("cancel_url", &cancel_url),
|
||||
("mode", "subscription"),
|
||||
("line_items[0][price]", &body.price_id),
|
||||
("line_items[0][quantity]", "1"),
|
||||
("metadata[user_id]", &auth.user_id.to_string()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let session_data: serde_json::Value = session_res.json().await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let url = session_data["url"].as_str().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(CheckoutResponse { checkout_url: url.to_string() }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StripeWebhook {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: String,
|
||||
pub data: StripeEventData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StripeEventData {
|
||||
pub object: serde_json::Value,
|
||||
}
|
||||
|
||||
pub async fn webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
let stripe_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
|
||||
|
||||
// Verify webhook signature if configured
|
||||
if !stripe_secret.is_empty() {
|
||||
let sig = headers
|
||||
.get("stripe-signature")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// In production, verify signature using Stripe library
|
||||
// For MVP, we log and process
|
||||
tracing::info!("Webhook signature: {}", sig);
|
||||
}
|
||||
|
||||
let event: serde_json::Value = serde_json::from_str(&body).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"checkout.session.completed" => {
|
||||
if let Some(metadata) = event["data"]["object"]["metadata"].as_object() {
|
||||
if let Some(user_id_str) = metadata.get("user_id").and_then(|v| v.as_str()) {
|
||||
if let Ok(user_id) = uuid::Uuid::parse_str(user_id_str) {
|
||||
let customer_id = event["data"]["object"]["customer"].as_str().unwrap_or("");
|
||||
let subscription_id = event["data"]["object"]["subscription"].as_str().unwrap_or("");
|
||||
let _ = subscriptions::create_or_update(
|
||||
&state.db,
|
||||
user_id,
|
||||
Some(customer_id),
|
||||
Some(subscription_id),
|
||||
None,
|
||||
"active",
|
||||
"paid",
|
||||
).await;
|
||||
tracing::info!("Subscription activated for user {}", user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"invoice.payment_succeeded" => {
|
||||
tracing::info!("Invoice payment succeeded");
|
||||
}
|
||||
"customer.subscription.deleted" => {
|
||||
if let Some(sub_id) = event["data"]["object"]["id"].as_str() {
|
||||
let _ = subscriptions::update_status(&state.db, sub_id, "canceled").await;
|
||||
tracing::info!("Subscription {} canceled", sub_id);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
95
crates/api/src/routes/teams.rs
Normal file
95
crates/api/src/routes/teams.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use axum::{
|
||||
extract::{Json, Path, State},
|
||||
http::StatusCode,
|
||||
Extension,
|
||||
};
|
||||
use db::repos::teams;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shared::models::{Team, TeamMember};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{middleware::jwt::JwtAuth, state::AppState};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateTeamRequest {
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TeamResponse {
|
||||
pub team: Team,
|
||||
pub members: Vec<TeamMember>,
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Json(body): Json<CreateTeamRequest>,
|
||||
) -> Result<Json<Team>, StatusCode> {
|
||||
let team = teams::create(&state.db, &body.name, &body.slug, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Add owner as member
|
||||
let _ = teams::add_member(&state.db, team.id, auth.user_id, "owner").await;
|
||||
|
||||
Ok(Json(team))
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(slug): Path<String>,
|
||||
) -> Result<Json<TeamResponse>, StatusCode> {
|
||||
let team = teams::find_by_slug(&state.db, &slug)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Check membership
|
||||
let member = teams::find_member(&state.db, team.id, auth.user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if member.is_none() {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let members = teams::list_members(&state.db, team.id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(TeamResponse { team, members }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AddMemberRequest {
|
||||
pub user_id: Uuid,
|
||||
#[serde(default)]
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
pub async fn add_member(
|
||||
State(state): State<AppState>,
|
||||
Extension(auth): Extension<JwtAuth>,
|
||||
Path(slug): Path<String>,
|
||||
Json(body): Json<AddMemberRequest>,
|
||||
) -> Result<Json<TeamMember>, StatusCode> {
|
||||
let team = teams::find_by_slug(&state.db, &slug)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Only owner can add members
|
||||
if team.owner_id != auth.user_id {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let role = if body.role.is_empty() { "member" } else { &body.role };
|
||||
let member = teams::add_member(&state.db, team.id, body.user_id, role)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(member))
|
||||
}
|
||||
41
crates/api/src/routes/ws.rs
Normal file
41
crates/api/src/routes/ws.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use axum::{
|
||||
extract::ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub async fn live_logs(ws: WebSocketUpgrade) -> impl IntoResponse {
|
||||
ws.on_upgrade(handle_socket)
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket) {
|
||||
// Send initial connection message
|
||||
let _ = socket.send(Message::Text(r#"{"type":"connected","message":"Live logs connected"}"#.to_string())).await;
|
||||
|
||||
// In a real implementation, this would subscribe to a Redis pub/sub channel
|
||||
// and stream logs to the client. For MVP, we send a heartbeat.
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
let msg = r#"{"type":"heartbeat","timestamp":""#.to_string()
|
||||
+ &chrono::Utc::now().to_rfc3339()
|
||||
+ "\"}";
|
||||
if socket.send(Message::Text(msg)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
if text == "ping" {
|
||||
let _ = socket.send(Message::Text("pong".to_string())).await;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
69
crates/api/src/secrets/mod.rs
Normal file
69
crates/api/src/secrets/mod.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use shared::error::AppError;
|
||||
|
||||
pub async fn get_secret(key: &str) -> Result<String, AppError> {
|
||||
// Priority 1: Environment variable (for local dev)
|
||||
if let Ok(val) = std::env::var(key) {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
// Priority 2: Try Vault
|
||||
if let Ok(val) = get_vault_secret(key).await {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
// Priority 3: Try AWS Secrets Manager
|
||||
if let Ok(val) = get_aws_secret(key).await {
|
||||
return Ok(val);
|
||||
}
|
||||
|
||||
Err(AppError::Internal(format!("Secret {} not found in any provider", key)))
|
||||
}
|
||||
|
||||
async fn get_vault_secret(key: &str) -> Result<String, AppError> {
|
||||
let vault_addr = match std::env::var("VAULT_ADDR") {
|
||||
Ok(addr) => addr,
|
||||
Err(_) => return Err(AppError::Internal("VAULT_ADDR not set".to_string())),
|
||||
};
|
||||
|
||||
let vault_token = match std::env::var("VAULT_TOKEN") {
|
||||
Ok(token) => token,
|
||||
Err(_) => return Err(AppError::Internal("VAULT_TOKEN not set".to_string())),
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.get(format!("{}/v1/secret/data/{}", vault_addr, key))
|
||||
.header("X-Vault-Token", vault_token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Vault request failed: {}", e)))?;
|
||||
|
||||
let data: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Vault response parse failed: {}", e)))?;
|
||||
|
||||
data["data"]["data"][key]
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in Vault", key)))
|
||||
}
|
||||
|
||||
async fn get_aws_secret(key: &str) -> Result<String, AppError> {
|
||||
let secret_name = format!("crawlapi/{}", key.to_lowercase().replace('_', "/"));
|
||||
|
||||
let config = aws_config::from_env().load().await;
|
||||
let client = aws_sdk_secretsmanager::Client::new(&config);
|
||||
|
||||
let response = client
|
||||
.get_secret_value()
|
||||
.secret_id(&secret_name)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("AWS Secrets Manager error: {}", e)))?;
|
||||
|
||||
response
|
||||
.secret_string()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| AppError::Internal(format!("Secret {} not found in AWS", key)))
|
||||
}
|
||||
13
crates/api/src/state.rs
Normal file
13
crates/api/src/state.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use db::DbPool;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use shared::config::AppConfig;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub config: Arc<AppConfig>,
|
||||
pub db: DbPool,
|
||||
pub redis: MultiplexedConnection,
|
||||
pub s3: S3Client,
|
||||
}
|
||||
54
crates/api/src/storage/mod.rs
Normal file
54
crates/api/src/storage/mod.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use shared::error::AppError;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn upload_file(
|
||||
s3: &S3Client,
|
||||
bucket: &str,
|
||||
key: &str,
|
||||
content_type: &str,
|
||||
data: Vec<u8>,
|
||||
) -> Result<String, AppError> {
|
||||
s3.put_object()
|
||||
.bucket(bucket)
|
||||
.key(key)
|
||||
.content_type(content_type)
|
||||
.body(data.into())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::S3(e.to_string()))?;
|
||||
|
||||
Ok(format!("{}/{}", bucket, key))
|
||||
}
|
||||
|
||||
pub async fn ensure_bucket_exists(s3: &S3Client, bucket: &str) -> Result<(), AppError> {
|
||||
let exists = s3
|
||||
.head_bucket()
|
||||
.bucket(bucket)
|
||||
.send()
|
||||
.await
|
||||
.is_ok();
|
||||
|
||||
if !exists {
|
||||
s3.create_bucket()
|
||||
.bucket(bucket)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::S3(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn generate_file_key(endpoint: &str, ext: &str) -> String {
|
||||
let id = Uuid::new_v4().to_string().replace('-', "");
|
||||
match endpoint {
|
||||
"screenshot" => format!("screenshots/{}.png", id),
|
||||
"pdf" => format!("pdfs/{}.pdf", id),
|
||||
_ => format!("files/{}.{}", id, ext),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_public_url(_endpoint: &str, s3_endpoint: &str, bucket: &str, key: &str) -> String {
|
||||
format!("{}/{}/{}", s3_endpoint, bucket, key)
|
||||
}
|
||||
62
crates/api/src/validation.rs
Normal file
62
crates/api/src/validation.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use shared::error::AppError;
|
||||
|
||||
pub fn validate_url(url: &str) -> Result<(), AppError> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
|
||||
|
||||
// Only allow http and https
|
||||
if parsed.scheme() != "http" && parsed.scheme() != "https" {
|
||||
return Err(AppError::InvalidUrl("Only HTTP and HTTPS URLs are allowed".to_string()));
|
||||
}
|
||||
|
||||
// Block private IP ranges
|
||||
if let Some(host) = parsed.host_str() {
|
||||
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
|
||||
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
|
||||
}
|
||||
if host.starts_with("172.") {
|
||||
if let Some(seg) = host.split('.').nth(1) {
|
||||
if let Ok(n) = seg.parse::<u8>() {
|
||||
if n >= 16 && n <= 31 {
|
||||
return Err(AppError::InvalidUrl("Private IP addresses are not allowed".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Block file://, ftp://, etc.
|
||||
if parsed.scheme() == "file" {
|
||||
return Err(AppError::InvalidUrl("File URLs are not allowed".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_webhook_url(url: &str) -> Result<(), AppError> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| AppError::InvalidUrl(url.to_string()))?;
|
||||
|
||||
// Only allow http and https
|
||||
if parsed.scheme() != "http" && parsed.scheme() != "https" {
|
||||
return Err(AppError::InvalidUrl("Webhook must use HTTP or HTTPS".to_string()));
|
||||
}
|
||||
|
||||
// Block private IPs and localhost for webhooks
|
||||
if let Some(host) = parsed.host_str() {
|
||||
if host == "localhost" || host == "127.0.0.1" || host.starts_with("10.") || host.starts_with("192.168.") {
|
||||
return Err(AppError::InvalidUrl("Webhook cannot point to private addresses".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_size(content: &[u8], max_mb: usize) -> Result<(), AppError> {
|
||||
let max_bytes = max_mb * 1024 * 1024;
|
||||
if content.len() > max_bytes {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Content exceeds maximum size of {}MB",
|
||||
max_mb
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
29
crates/api/tests/integration_test.rs
Normal file
29
crates/api/tests/integration_test.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_check() {
|
||||
// This is a placeholder for integration tests
|
||||
// In a real setup, you would spawn the API server and make HTTP requests
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crawl_request_validation() {
|
||||
let req = shared::models::CrawlRequest {
|
||||
url: "https://example.com".to_string(),
|
||||
options: shared::models::CrawlOptions::default(),
|
||||
};
|
||||
assert_eq!(req.url, "https://example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_response_format() {
|
||||
let response = shared::api::ApiResponse::ok(json!({"test": true}));
|
||||
assert!(response.success);
|
||||
assert!(response.data.is_some());
|
||||
assert!(response.error.is_none());
|
||||
|
||||
let error = shared::api::ApiResponse::<()>::err("Something went wrong");
|
||||
assert!(!error.success);
|
||||
assert!(error.error.is_some());
|
||||
}
|
||||
17
crates/db/Cargo.toml
Normal file
17
crates/db/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "db"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
shared = { path = "../shared" }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = { workspace = true }
|
||||
38
crates/db/migrations/001_init.sql
Normal file
38
crates/db/migrations/001_init.sql
Normal file
@@ -0,0 +1,38 @@
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
password_hash VARCHAR(255),
|
||||
google_id VARCHAR(255) UNIQUE,
|
||||
credits BIGINT NOT NULL DEFAULT 30,
|
||||
tier VARCHAR(50) NOT NULL DEFAULT 'free',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE api_keys (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
key_hash VARCHAR(255) UNIQUE NOT NULL,
|
||||
name VARCHAR(255) NOT NULL DEFAULT 'Default',
|
||||
last_used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE usage_logs (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
api_key_id UUID NOT NULL REFERENCES api_keys(id) ON DELETE CASCADE,
|
||||
endpoint VARCHAR(100) NOT NULL,
|
||||
url TEXT NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
credits_used BIGINT NOT NULL DEFAULT 1,
|
||||
duration_ms BIGINT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_api_keys_user_id ON api_keys(user_id);
|
||||
CREATE INDEX idx_api_keys_key_hash ON api_keys(key_hash);
|
||||
CREATE INDEX idx_usage_logs_user_id ON usage_logs(user_id);
|
||||
CREATE INDEX idx_usage_logs_created_at ON usage_logs(created_at);
|
||||
27
crates/db/migrations/002_oauth_and_subscriptions.sql
Normal file
27
crates/db/migrations/002_oauth_and_subscriptions.sql
Normal file
@@ -0,0 +1,27 @@
|
||||
CREATE TABLE oauth_accounts (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
provider VARCHAR(50) NOT NULL,
|
||||
provider_account_id VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(provider, provider_account_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_oauth_accounts_user_id ON oauth_accounts(user_id);
|
||||
|
||||
CREATE TABLE subscriptions (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
stripe_customer_id VARCHAR(255),
|
||||
stripe_subscription_id VARCHAR(255),
|
||||
stripe_price_id VARCHAR(255),
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'incomplete',
|
||||
tier VARCHAR(50) NOT NULL DEFAULT 'free',
|
||||
current_period_start TIMESTAMPTZ,
|
||||
current_period_end TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_subscriptions_user_id ON subscriptions(user_id);
|
||||
CREATE INDEX idx_subscriptions_stripe_customer ON subscriptions(stripe_customer_id);
|
||||
21
crates/db/migrations/003_teams.sql
Normal file
21
crates/db/migrations/003_teams.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
CREATE TABLE teams (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
slug VARCHAR(255) UNIQUE NOT NULL,
|
||||
owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE team_members (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
team_id UUID NOT NULL REFERENCES teams(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role VARCHAR(50) NOT NULL DEFAULT 'member',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(team_id, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_teams_owner ON teams(owner_id);
|
||||
CREATE INDEX idx_team_members_team ON team_members(team_id);
|
||||
CREATE INDEX idx_team_members_user ON team_members(user_id);
|
||||
7
crates/db/src/connection.rs
Normal file
7
crates/db/src/connection.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use sqlx::PgPool;
|
||||
|
||||
pub type DbPool = PgPool;
|
||||
|
||||
pub async fn create_pool(database_url: &str) -> Result<DbPool, sqlx::Error> {
|
||||
PgPool::connect(database_url).await
|
||||
}
|
||||
4
crates/db/src/lib.rs
Normal file
4
crates/db/src/lib.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod connection;
|
||||
pub mod repos;
|
||||
|
||||
pub use connection::DbPool;
|
||||
64
crates/db/src/repos/api_keys.rs
Normal file
64
crates/db/src/repos/api_keys.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use shared::models::ApiKey;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn find_by_key_hash(pool: &PgPool, key_hash: &str) -> Result<Option<ApiKey>, sqlx::Error> {
|
||||
sqlx::query_as::<_, ApiKey>(
|
||||
r#"SELECT id, user_id, key_hash, name, last_used_at, created_at FROM api_keys WHERE key_hash = $1"#,
|
||||
)
|
||||
.bind(key_hash)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
key_hash: &str,
|
||||
name: &str,
|
||||
) -> Result<ApiKey, sqlx::Error> {
|
||||
sqlx::query_as::<_, ApiKey>(
|
||||
r#"INSERT INTO api_keys (id, user_id, key_hash, name)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, user_id, key_hash, name, last_used_at, created_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(user_id)
|
||||
.bind(key_hash)
|
||||
.bind(name)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_by_user(pool: &PgPool, user_id: Uuid) -> Result<Vec<ApiKey>, sqlx::Error> {
|
||||
sqlx::query_as::<_, ApiKey>(
|
||||
r#"SELECT id, user_id, key_hash, name, last_used_at, created_at FROM api_keys WHERE user_id = $1 ORDER BY created_at DESC"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update_last_used(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"UPDATE api_keys SET last_used_at = $1 WHERE id = $2"#,
|
||||
)
|
||||
.bind(chrono::Utc::now())
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_by_id(pool: &PgPool, id: Uuid, user_id: Uuid) -> Result<bool, sqlx::Error> {
|
||||
let result = sqlx::query(
|
||||
r#"DELETE FROM api_keys WHERE id = $1 AND user_id = $2"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
6
crates/db/src/repos/mod.rs
Normal file
6
crates/db/src/repos/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod api_keys;
|
||||
pub mod oauth;
|
||||
pub mod subscriptions;
|
||||
pub mod teams;
|
||||
pub mod usage_logs;
|
||||
pub mod users;
|
||||
37
crates/db/src/repos/oauth.rs
Normal file
37
crates/db/src/repos/oauth.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use shared::models::OAuthAccount;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn find_by_provider(
|
||||
pool: &PgPool,
|
||||
provider: &str,
|
||||
provider_account_id: &str,
|
||||
) -> Result<Option<OAuthAccount>, sqlx::Error> {
|
||||
sqlx::query_as::<_, OAuthAccount>(
|
||||
r#"SELECT id, user_id, provider, provider_account_id, created_at
|
||||
FROM oauth_accounts WHERE provider = $1 AND provider_account_id = $2"#,
|
||||
)
|
||||
.bind(provider)
|
||||
.bind(provider_account_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
provider: &str,
|
||||
provider_account_id: &str,
|
||||
) -> Result<OAuthAccount, sqlx::Error> {
|
||||
sqlx::query_as::<_, OAuthAccount>(
|
||||
r#"INSERT INTO oauth_accounts (id, user_id, provider, provider_account_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, user_id, provider, provider_account_id, created_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(user_id)
|
||||
.bind(provider)
|
||||
.bind(provider_account_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
76
crates/db/src/repos/subscriptions.rs
Normal file
76
crates/db/src/repos/subscriptions.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use shared::models::Subscription;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn find_by_user(pool: &PgPool, user_id: Uuid) -> Result<Option<Subscription>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Subscription>(
|
||||
r#"SELECT id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
status, tier, current_period_start, current_period_end, created_at, updated_at
|
||||
FROM subscriptions WHERE user_id = $1 ORDER BY created_at DESC LIMIT 1"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_stripe_subscription(
|
||||
pool: &PgPool,
|
||||
stripe_subscription_id: &str,
|
||||
) -> Result<Option<Subscription>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Subscription>(
|
||||
r#"SELECT id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
status, tier, current_period_start, current_period_end, created_at, updated_at
|
||||
FROM subscriptions WHERE stripe_subscription_id = $1"#,
|
||||
)
|
||||
.bind(stripe_subscription_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create_or_update(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
stripe_customer_id: Option<&str>,
|
||||
stripe_subscription_id: Option<&str>,
|
||||
stripe_price_id: Option<&str>,
|
||||
status: &str,
|
||||
tier: &str,
|
||||
) -> Result<Subscription, sqlx::Error> {
|
||||
sqlx::query_as::<_, Subscription>(
|
||||
r#"INSERT INTO subscriptions (id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id, status, tier)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
stripe_customer_id = EXCLUDED.stripe_customer_id,
|
||||
stripe_subscription_id = EXCLUDED.stripe_subscription_id,
|
||||
stripe_price_id = EXCLUDED.stripe_price_id,
|
||||
status = EXCLUDED.status,
|
||||
tier = EXCLUDED.tier,
|
||||
updated_at = NOW()
|
||||
RETURNING id, user_id, stripe_customer_id, stripe_subscription_id, stripe_price_id,
|
||||
status, tier, current_period_start, current_period_end, created_at, updated_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(user_id)
|
||||
.bind(stripe_customer_id)
|
||||
.bind(stripe_subscription_id)
|
||||
.bind(stripe_price_id)
|
||||
.bind(status)
|
||||
.bind(tier)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn update_status(
|
||||
pool: &PgPool,
|
||||
stripe_subscription_id: &str,
|
||||
status: &str,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"UPDATE subscriptions SET status = $1, updated_at = NOW() WHERE stripe_subscription_id = $2"#,
|
||||
)
|
||||
.bind(status)
|
||||
.bind(stripe_subscription_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
68
crates/db/src/repos/teams.rs
Normal file
68
crates/db/src/repos/teams.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use shared::models::{Team, TeamMember};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn create(pool: &PgPool, name: &str, slug: &str, owner_id: Uuid) -> Result<Team, sqlx::Error> {
|
||||
sqlx::query_as::<_, Team>(
|
||||
r#"INSERT INTO teams (id, name, slug, owner_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, name, slug, owner_id, created_at, updated_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(name)
|
||||
.bind(slug)
|
||||
.bind(owner_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_slug(pool: &PgPool, slug: &str) -> Result<Option<Team>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Team>(
|
||||
r#"SELECT id, name, slug, owner_id, created_at, updated_at FROM teams WHERE slug = $1"#,
|
||||
)
|
||||
.bind(slug)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Team>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Team>(
|
||||
r#"SELECT id, name, slug, owner_id, created_at, updated_at FROM teams WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn add_member(pool: &PgPool, team_id: Uuid, user_id: Uuid, role: &str) -> Result<TeamMember, sqlx::Error> {
|
||||
sqlx::query_as::<_, TeamMember>(
|
||||
r#"INSERT INTO team_members (id, team_id, user_id, role)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, team_id, user_id, role, created_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(team_id)
|
||||
.bind(user_id)
|
||||
.bind(role)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_members(pool: &PgPool, team_id: Uuid) -> Result<Vec<TeamMember>, sqlx::Error> {
|
||||
sqlx::query_as::<_, TeamMember>(
|
||||
r#"SELECT id, team_id, user_id, role, created_at FROM team_members WHERE team_id = $1"#,
|
||||
)
|
||||
.bind(team_id)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_member(pool: &PgPool, team_id: Uuid, user_id: Uuid) -> Result<Option<TeamMember>, sqlx::Error> {
|
||||
sqlx::query_as::<_, TeamMember>(
|
||||
r#"SELECT id, team_id, user_id, role, created_at FROM team_members WHERE team_id = $1 AND user_id = $2"#,
|
||||
)
|
||||
.bind(team_id)
|
||||
.bind(user_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
47
crates/db/src/repos/usage_logs.rs
Normal file
47
crates/db/src/repos/usage_logs.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use shared::models::UsageLog;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn create(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
api_key_id: Uuid,
|
||||
endpoint: &str,
|
||||
url: &str,
|
||||
status: &str,
|
||||
credits_used: i64,
|
||||
duration_ms: i64,
|
||||
) -> Result<UsageLog, sqlx::Error> {
|
||||
sqlx::query_as::<_, UsageLog>(
|
||||
r#"INSERT INTO usage_logs (id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms, created_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(user_id)
|
||||
.bind(api_key_id)
|
||||
.bind(endpoint)
|
||||
.bind(url)
|
||||
.bind(status)
|
||||
.bind(credits_used)
|
||||
.bind(duration_ms)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_by_user(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
limit: i64,
|
||||
offset: i64,
|
||||
) -> Result<Vec<UsageLog>, sqlx::Error> {
|
||||
sqlx::query_as::<_, UsageLog>(
|
||||
r#"SELECT id, user_id, api_key_id, endpoint, url, status, credits_used, duration_ms, created_at
|
||||
FROM usage_logs WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
}
|
||||
64
crates/db/src/repos/users.rs
Normal file
64
crates/db/src/repos/users.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use shared::models::User;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn find_by_email(pool: &PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"SELECT id, email, password_hash, google_id, credits, tier, created_at, updated_at FROM users WHERE email = $1"#,
|
||||
)
|
||||
.bind(email)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn find_by_id(pool: &PgPool, id: Uuid) -> Result<Option<User>, sqlx::Error> {
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"SELECT id, email, password_hash, google_id, credits, tier, created_at, updated_at FROM users WHERE id = $1"#,
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
pool: &PgPool,
|
||||
email: &str,
|
||||
password_hash: Option<&str>,
|
||||
google_id: Option<&str>,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
sqlx::query_as::<_, User>(
|
||||
r#"INSERT INTO users (id, email, password_hash, google_id, credits, tier)
|
||||
VALUES ($1, $2, $3, $4, 30, 'free')
|
||||
RETURNING id, email, password_hash, google_id, credits, tier, created_at, updated_at"#,
|
||||
)
|
||||
.bind(Uuid::new_v4())
|
||||
.bind(email)
|
||||
.bind(password_hash)
|
||||
.bind(google_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn deduct_credits(pool: &PgPool, user_id: Uuid, amount: i64) -> Result<bool, sqlx::Error> {
|
||||
let result = sqlx::query(
|
||||
r#"UPDATE users SET credits = credits - $1 WHERE id = $2 AND credits >= $1"#,
|
||||
)
|
||||
.bind(amount)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn add_credits(pool: &PgPool, user_id: Uuid, amount: i64) -> Result<(), sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"UPDATE users SET credits = credits + $1 WHERE id = $2"#,
|
||||
)
|
||||
.bind(amount)
|
||||
.bind(user_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
18
crates/db/tests/db_test.rs
Normal file
18
crates/db/tests/db_test.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use shared::models::User;
|
||||
|
||||
#[test]
|
||||
fn test_user_model_serialization() {
|
||||
let user = User {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
email: "test@example.com".to_string(),
|
||||
password_hash: Some("hash".to_string()),
|
||||
google_id: None,
|
||||
credits: 30,
|
||||
tier: "free".to_string(),
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&user).unwrap();
|
||||
assert!(json.contains("test@example.com"));
|
||||
}
|
||||
15
crates/shared/Cargo.toml
Normal file
15
crates/shared/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "shared"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
url = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
config = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
26
crates/shared/src/api.rs
Normal file
26
crates/shared/src/api.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiResponse<T> {
|
||||
pub success: bool,
|
||||
pub data: Option<T>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl<T> ApiResponse<T> {
|
||||
pub fn ok(data: T) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
data: Some(data),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(msg.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
25
crates/shared/src/config.rs
Normal file
25
crates/shared/src/config.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
pub database_url: String,
|
||||
pub redis_url: String,
|
||||
pub jwt_secret: String,
|
||||
pub s3_endpoint: String,
|
||||
pub s3_bucket: String,
|
||||
pub s3_region: String,
|
||||
pub s3_access_key: String,
|
||||
pub s3_secret_key: String,
|
||||
pub app_port: u16,
|
||||
pub app_host: String,
|
||||
pub playwright_script_path: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub fn from_env() -> Result<Self, config::ConfigError> {
|
||||
config::Config::builder()
|
||||
.add_source(config::Environment::default())
|
||||
.build()?
|
||||
.try_deserialize()
|
||||
}
|
||||
}
|
||||
41
crates/shared/src/error.rs
Normal file
41
crates/shared/src/error.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AppError {
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(String),
|
||||
#[error("S3 error: {0}")]
|
||||
S3(String),
|
||||
#[error("Invalid URL: {0}")]
|
||||
InvalidUrl(String),
|
||||
#[error("Browser automation failed: {0}")]
|
||||
BrowserError(String),
|
||||
#[error("Rate limit exceeded")]
|
||||
RateLimit,
|
||||
#[error("Insufficient credits")]
|
||||
InsufficientCredits,
|
||||
#[error("Unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("Not found")]
|
||||
NotFound,
|
||||
#[error("Bad request: {0}")]
|
||||
BadRequest(String),
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
pub fn status_code(&self) -> u16 {
|
||||
match self {
|
||||
AppError::InvalidUrl(_) | AppError::BadRequest(_) => 400,
|
||||
AppError::Unauthorized => 401,
|
||||
AppError::InsufficientCredits => 403,
|
||||
AppError::NotFound => 404,
|
||||
AppError::RateLimit => 429,
|
||||
AppError::BrowserError(_) => 500,
|
||||
_ => 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
24
crates/shared/src/jobs.rs
Normal file
24
crates/shared/src/jobs.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::models::CrawlOptions;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrawlJob {
|
||||
pub job_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub api_key_id: Uuid,
|
||||
pub endpoint: String,
|
||||
pub url: String,
|
||||
pub options: CrawlOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrawlResult {
|
||||
pub job_id: Uuid,
|
||||
pub success: bool,
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub error: Option<String>,
|
||||
pub duration_ms: i64,
|
||||
pub file_url: Option<String>,
|
||||
}
|
||||
6
crates/shared/src/lib.rs
Normal file
6
crates/shared/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod api;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod jobs;
|
||||
pub mod models;
|
||||
pub mod queue;
|
||||
136
crates/shared/src/models.rs
Normal file
136
crates/shared/src/models.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub email: String,
|
||||
pub password_hash: Option<String>,
|
||||
pub google_id: Option<String>,
|
||||
pub credits: i64,
|
||||
pub tier: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct ApiKey {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub key_hash: String,
|
||||
pub name: String,
|
||||
pub last_used_at: Option<DateTime<Utc>>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct OAuthAccount {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub provider: String,
|
||||
pub provider_account_id: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Subscription {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub stripe_customer_id: Option<String>,
|
||||
pub stripe_subscription_id: Option<String>,
|
||||
pub stripe_price_id: Option<String>,
|
||||
pub status: String,
|
||||
pub tier: String,
|
||||
pub current_period_start: Option<DateTime<Utc>>,
|
||||
pub current_period_end: Option<DateTime<Utc>>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct UsageLog {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub api_key_id: Uuid,
|
||||
pub endpoint: String,
|
||||
pub url: String,
|
||||
pub status: String,
|
||||
pub credits_used: i64,
|
||||
pub duration_ms: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Team {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub slug: String,
|
||||
pub owner_id: Uuid,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct TeamMember {
|
||||
pub id: Uuid,
|
||||
pub team_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrawlRequest {
|
||||
pub url: String,
|
||||
#[serde(default)]
|
||||
pub options: CrawlOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct CrawlOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub full_page: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub width: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub height: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub wait_for: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user_agent: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub selectors: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_html: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub webhook_url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<std::collections::HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mobile: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scroll_to_bottom: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stealth: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_proxy: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub solve_captcha: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CrawlResponse {
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub calls_remaining: Option<i64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
27
crates/shared/src/queue.rs
Normal file
27
crates/shared/src/queue.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::models::CrawlOptions;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Job {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub api_key_id: Uuid,
|
||||
pub endpoint: String,
|
||||
pub url: String,
|
||||
pub options: CrawlOptions,
|
||||
pub webhook_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JobResult {
|
||||
pub id: Uuid,
|
||||
pub success: bool,
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub error: Option<String>,
|
||||
pub duration_ms: i64,
|
||||
}
|
||||
|
||||
pub const QUEUE_NAME: &str = "crawlapi:jobs";
|
||||
pub const RESULT_PREFIX: &str = "crawlapi:results:";
|
||||
26
crates/worker/Cargo.toml
Normal file
26
crates/worker/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "worker"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
shared = { path = "../shared" }
|
||||
db = { path = "../db" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["json", "env-filter"] }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
aws-config = { workspace = true }
|
||||
aws-sdk-s3 = { workspace = true }
|
||||
config = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
sentry = "0.36"
|
||||
sqlx = { workspace = true }
|
||||
230
crates/worker/src/main.rs
Normal file
230
crates/worker/src/main.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
use chrono::Utc;
|
||||
use db::connection::create_pool;
|
||||
use redis::AsyncCommands;
|
||||
use shared::{
|
||||
config::AppConfig,
|
||||
queue::{Job, JobResult, QUEUE_NAME, RESULT_PREFIX},
|
||||
};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::sleep;
|
||||
use tracing::{info_span, Instrument};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let sentry_dsn = std::env::var("SENTRY_DSN").ok();
|
||||
let _guard = sentry_dsn.map(|dsn| {
|
||||
sentry::init((dsn, sentry::ClientOptions {
|
||||
release: sentry::release_name!(),
|
||||
..Default::default()
|
||||
}))
|
||||
});
|
||||
|
||||
let json_logging = std::env::var("JSON_LOGGING").unwrap_or_else(|_| "false".to_string()) == "true";
|
||||
|
||||
if json_logging {
|
||||
tracing_subscriber::registry()
|
||||
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "worker=debug".into()))
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.json()
|
||||
.with_current_span(true)
|
||||
.with_span_list(true)
|
||||
.with_target(true),
|
||||
)
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "worker=debug".into()))
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
}
|
||||
|
||||
let config = AppConfig::from_env()?;
|
||||
let db = create_pool(&config.database_url).await?;
|
||||
|
||||
let redis_client = redis::Client::open(config.redis_url.clone())?;
|
||||
let mut redis_conn = redis_client.get_multiplexed_tokio_connection().await?;
|
||||
|
||||
tracing::info!("Worker started. Waiting for jobs...");
|
||||
|
||||
loop {
|
||||
let job_json: Option<(String, String)> = redis::cmd("BLPOP")
|
||||
.arg(QUEUE_NAME)
|
||||
.arg(5)
|
||||
.query_async(&mut redis_conn)
|
||||
.await?;
|
||||
|
||||
if let Some((_, json)) = job_json {
|
||||
let job: Job = match serde_json::from_str(&json) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to deserialize job: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let span = info_span!(
|
||||
"process_job",
|
||||
job_id = %job.id,
|
||||
user_id = %job.user_id,
|
||||
endpoint = %job.endpoint,
|
||||
url = %job.url,
|
||||
);
|
||||
|
||||
process_single_job(&config, &db, &mut redis_conn, &job)
|
||||
.instrument(span)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_single_job(
|
||||
config: &AppConfig,
|
||||
db: &sqlx::PgPool,
|
||||
redis_conn: &mut redis::aio::MultiplexedConnection,
|
||||
job: &Job,
|
||||
) {
|
||||
tracing::info!("Processing job {}: {} {}", job.id, job.endpoint, job.url);
|
||||
let start = Instant::now();
|
||||
|
||||
let result = process_job_with_retry(config, job).await;
|
||||
let duration = start.elapsed().as_millis() as i64;
|
||||
|
||||
let job_result = match result {
|
||||
Ok(data) => JobResult {
|
||||
id: job.id,
|
||||
success: true,
|
||||
data: Some(data),
|
||||
error: None,
|
||||
duration_ms: duration,
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Job {} failed after retries: {}", job.id, e);
|
||||
JobResult {
|
||||
id: job.id,
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(e.clone()),
|
||||
duration_ms: duration,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let result_json = serde_json::to_string(&job_result).unwrap();
|
||||
let result_key = format!("{}{}", RESULT_PREFIX, job.id);
|
||||
let _: () = redis_conn.set_ex(&result_key, result_json, 300).await.unwrap_or(());
|
||||
|
||||
let status = if job_result.success { "success" } else { "error" };
|
||||
let _ = db::repos::usage_logs::create(
|
||||
db,
|
||||
job.user_id,
|
||||
job.api_key_id,
|
||||
&job.endpoint,
|
||||
&job.url,
|
||||
status,
|
||||
1,
|
||||
duration,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(webhook_url) = &job.webhook_url {
|
||||
let _ = send_webhook(webhook_url, &job_result).await;
|
||||
}
|
||||
|
||||
if !job_result.success {
|
||||
let dlq_key = format!("crawlapi:dlq:{}", job.id);
|
||||
let dlq_data = serde_json::json!({
|
||||
"job": job,
|
||||
"error": job_result.error,
|
||||
"failed_at": Utc::now().to_rfc3339(),
|
||||
});
|
||||
let _: () = redis_conn.set_ex(dlq_key, dlq_data.to_string(), 86400).await.unwrap_or(());
|
||||
tracing::warn!("Job {} moved to DLQ", job.id);
|
||||
}
|
||||
|
||||
tracing::info!("Job {} completed in {}ms", job.id, duration);
|
||||
}
|
||||
|
||||
async fn process_job_with_retry(config: &AppConfig, job: &Job) -> Result<serde_json::Value, String> {
|
||||
let max_retries = 3;
|
||||
let mut last_error = String::new();
|
||||
|
||||
for attempt in 0..max_retries {
|
||||
if attempt > 0 {
|
||||
let backoff = Duration::from_secs(2_u64.pow(attempt as u32));
|
||||
tracing::info!(
|
||||
"Retrying job {} (attempt {}/{}), waiting {:?}",
|
||||
job.id,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
backoff
|
||||
);
|
||||
sleep(backoff).await;
|
||||
}
|
||||
|
||||
match process_job(config, job).await {
|
||||
Ok(data) => return Ok(data),
|
||||
Err(e) => {
|
||||
last_error = e;
|
||||
tracing::warn!(
|
||||
"Job {} attempt {}/{} failed: {}",
|
||||
job.id,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
last_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!("Failed after {} retries: {}", max_retries, last_error))
|
||||
}
|
||||
|
||||
async fn process_job(config: &AppConfig, job: &Job) -> Result<serde_json::Value, String> {
|
||||
let script_path = &config.playwright_script_path;
|
||||
|
||||
let mut cmd = Command::new("node");
|
||||
cmd.arg(script_path)
|
||||
.arg(&job.endpoint)
|
||||
.arg(serde_json::to_string(&job.url).unwrap())
|
||||
.arg(serde_json::to_string(&job.options).unwrap())
|
||||
.env("OUTPUT_DIR", "/tmp/crawlapi")
|
||||
.env("BROWSER_POOL_SIZE", std::env::var("BROWSER_POOL_SIZE").unwrap_or_else(|_| "5".to_string()))
|
||||
.env("MAX_PAGES_PER_BROWSER", std::env::var("MAX_PAGES_PER_BROWSER").unwrap_or_else(|_| "10".to_string()));
|
||||
|
||||
if let Ok(proxy_url) = std::env::var("PROXY_URL") {
|
||||
cmd.env("PROXY_URL", proxy_url);
|
||||
}
|
||||
|
||||
if let Ok(captcha_key) = std::env::var("CAPTCHA_API_KEY") {
|
||||
cmd.env("CAPTCHA_API_KEY", captcha_key);
|
||||
}
|
||||
|
||||
let output = cmd.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to execute browser: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(format!("Browser error: {}", String::from_utf8_lossy(&output.stderr)));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let result: serde_json::Value = serde_json::from_str(&stdout)
|
||||
.map_err(|e| format!("Invalid JSON from browser: {} | output: {}", e, stdout))?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn send_webhook(url: &str, result: &JobResult) -> Result<(), reqwest::Error> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.build()?;
|
||||
let _ = client
|
||||
.post(url)
|
||||
.json(result)
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user