"""Tenant DB connection manager. Master DB: creates a fresh connection each time (very light load thanks to tenant_id → db_name cache, so we only hit master ~once per 5 min). Tenant DBs: use psycopg2.pool.ThreadedConnectionPool with maxconn=20. """ import time import threading import psycopg2 from psycopg2 import pool from config import MASTER_DB_URL, TENANT_DB_URL_TEMPLATE # ─── Tenant Pools ────────────────────────────── _tenant_pools = {} # ─── Tenant cache ────────────────────────────── _tenant_cache = {} _tenant_cache_ttl = 300 _tenant_cache_lock = threading.Lock() def _get_tenant_pool(db_name): """Lazy-initialize tenant DB connection pool by db_name.""" global _tenant_pools if db_name not in _tenant_pools: dsn = TENANT_DB_URL_TEMPLATE.format(db_name=db_name) _tenant_pools[db_name] = pool.ThreadedConnectionPool( minconn=2, maxconn=20, dsn=dsn ) return _tenant_pools[db_name] def _resolve_tenant_db(tenant_id): """Return db_name for tenant_id, using cache first.""" now = time.time() with _tenant_cache_lock: entry = _tenant_cache.get(tenant_id) if entry and entry['expires'] > now: return entry['db_name'] # Cache miss or expired — query master DB with a fresh connection conn = psycopg2.connect(MASTER_DB_URL) try: cur = conn.cursor() cur.execute( "SELECT db_name FROM tenants WHERE id = %s AND is_active = true", (tenant_id,) ) row = cur.fetchone() cur.close() db_name = row[0] if row else None finally: conn.close() if db_name: with _tenant_cache_lock: _tenant_cache[tenant_id] = {'db_name': db_name, 'expires': now + _tenant_cache_ttl} return db_name class _PooledConnection: """Thin wrapper that delegates all attribute access to the real psycopg2 connection, but returns it to the pool on .close(). """ __slots__ = ('_conn', '_pool') def __init__(self, conn, db_pool): self._conn = conn self._pool = db_pool def __getattr__(self, name): return getattr(self._conn, name) def close(self): try: if self._conn: try: self._conn.rollback() except Exception: pass self._pool.putconn(self._conn) except Exception: try: self._conn.close() except Exception: pass def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() # ─── Public API ──────────────────────────────── def get_master_conn(): """Get a direct connection to the master DB (no pool). Caller MUST close() the connection when done. """ return psycopg2.connect(MASTER_DB_URL) def get_tenant_conn(tenant_id): """Get a pooled connection to a tenant's DB.""" db_name = _resolve_tenant_db(tenant_id) if not db_name: raise ValueError(f"Tenant {tenant_id} not found or inactive") p = _get_tenant_pool(db_name) return _PooledConnection(p.getconn(), p) def get_tenant_conn_by_dbname(db_name): """Get a pooled connection to a tenant DB directly by name.""" p = _get_tenant_pool(db_name) return _PooledConnection(p.getconn(), p)