diff --git a/console/db.py b/console/db.py new file mode 100644 index 0000000..cda8815 --- /dev/null +++ b/console/db.py @@ -0,0 +1,770 @@ +""" +Database abstraction layer for the AUTOPARTES console application. + +Provides all data access methods the console app needs, reading from the +same SQLite database used by the Flask web dashboard. +""" + +import sqlite3 +from datetime import datetime, timedelta +from typing import Optional + +from console.config import DB_PATH + + +class Database: + """Thin abstraction over the vehicle_database SQLite database.""" + + def __init__(self, db_path: Optional[str] = None): + self.db_path = db_path or DB_PATH + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _connect(self) -> sqlite3.Connection: + """Open a connection with row_factory set to sqlite3.Row.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + return conn + + def _query(self, sql: str, params: tuple = (), one: bool = False): + """Execute a SELECT and return list[dict] (or a single dict if *one*).""" + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute(sql, params) + if one: + row = cursor.fetchone() + return dict(row) if row else None + return [dict(r) for r in cursor.fetchall()] + finally: + conn.close() + + def _execute(self, sql: str, params: tuple = ()) -> int: + """Execute an INSERT/UPDATE/DELETE and return lastrowid.""" + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute(sql, params) + conn.commit() + return cursor.lastrowid + finally: + conn.close() + + # ================================================================== + # Vehicle navigation + # ================================================================== + + def get_brands(self) -> list[dict]: + """Return all brands ordered by name: [{id, name, country}].""" + return self._query( + "SELECT id, name, country FROM brands ORDER BY name" + ) + + def get_models(self, brand: Optional[str] = None) -> list[dict]: + """Return models, optionally filtered by brand name (case-insensitive).""" + if brand: + return self._query( + """ + SELECT DISTINCT m.id, m.name + FROM models m + JOIN brands b ON m.brand_id = b.id + WHERE UPPER(b.name) = UPPER(?) + ORDER BY m.name + """, + (brand,), + ) + return self._query( + "SELECT id, name FROM models ORDER BY name" + ) + + def get_years( + self, brand: Optional[str] = None, model: Optional[str] = None + ) -> list[dict]: + """Return years, optionally filtered by brand and/or model.""" + sql = """ + SELECT DISTINCT y.id, y.year + FROM years y + JOIN model_year_engine mye ON y.id = mye.year_id + JOIN models m ON mye.model_id = m.id + JOIN brands b ON m.brand_id = b.id + WHERE 1=1 + """ + params: list = [] + if brand: + sql += " AND UPPER(b.name) = UPPER(?)" + params.append(brand) + if model: + sql += " AND UPPER(m.name) = UPPER(?)" + params.append(model) + sql += " ORDER BY y.year DESC" + return self._query(sql, tuple(params)) + + def get_engines( + self, + brand: Optional[str] = None, + model: Optional[str] = None, + year: Optional[int] = None, + ) -> list[dict]: + """Return engines, optionally filtered by brand/model/year.""" + sql = """ + SELECT DISTINCT e.id, e.name, e.displacement_cc, e.cylinders, + e.fuel_type, e.power_hp, e.torque_nm, e.engine_code + FROM engines e + JOIN model_year_engine mye ON e.id = mye.engine_id + JOIN models m ON mye.model_id = m.id + JOIN brands b ON m.brand_id = b.id + JOIN years y ON mye.year_id = y.id + WHERE 1=1 + """ + params: list = [] + if brand: + sql += " AND UPPER(b.name) = UPPER(?)" + params.append(brand) + if model: + sql += " AND UPPER(m.name) = UPPER(?)" + params.append(model) + if year: + sql += " AND y.year = ?" + params.append(int(year)) + sql += " ORDER BY e.name" + return self._query(sql, tuple(params)) + + def get_model_year_engine( + self, + brand: str, + model: str, + year: int, + engine_id: Optional[int] = None, + ) -> list[dict]: + """Return model_year_engine records for a specific vehicle config.""" + sql = """ + SELECT + mye.id, + b.name AS brand, + m.name AS model, + y.year, + e.id AS engine_id, + e.name AS engine, + mye.trim_level, + mye.drivetrain, + mye.transmission + FROM model_year_engine mye + JOIN models m ON mye.model_id = m.id + JOIN brands b ON m.brand_id = b.id + JOIN years y ON mye.year_id = y.id + JOIN engines e ON mye.engine_id = e.id + WHERE UPPER(b.name) = UPPER(?) + AND UPPER(m.name) = UPPER(?) + AND y.year = ? + """ + params: list = [brand, model, int(year)] + if engine_id: + sql += " AND e.id = ?" + params.append(engine_id) + sql += " ORDER BY e.name, mye.trim_level" + return self._query(sql, tuple(params)) + + # ================================================================== + # Parts catalog + # ================================================================== + + def get_categories(self) -> list[dict]: + """Return all part categories ordered by display_order.""" + return self._query( + """ + SELECT id, name, name_es, slug, icon_name, display_order + FROM part_categories + ORDER BY display_order, name + """ + ) + + def get_groups(self, category_id: int) -> list[dict]: + """Return part groups for a given category.""" + return self._query( + """ + SELECT id, name, name_es, slug, display_order + FROM part_groups + WHERE category_id = ? + ORDER BY display_order, name + """, + (category_id,), + ) + + def get_parts( + self, + group_id: Optional[int] = None, + mye_id: Optional[int] = None, + page: int = 1, + per_page: int = 15, + ) -> list[dict]: + """Return parts with optional group/vehicle filter and pagination.""" + per_page = min(per_page, 100) + offset = (page - 1) * per_page + + sql = """ + SELECT + p.id, + p.oem_part_number, + p.name, + p.name_es, + p.group_id, + pg.name AS group_name, + pc.name AS category_name + FROM parts p + JOIN part_groups pg ON p.group_id = pg.id + JOIN part_categories pc ON pg.category_id = pc.id + """ + where_parts: list[str] = [] + params: list = [] + + if group_id: + where_parts.append("p.group_id = ?") + params.append(group_id) + if mye_id: + where_parts.append( + "p.id IN (SELECT part_id FROM vehicle_parts WHERE model_year_engine_id = ?)" + ) + params.append(mye_id) + + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + sql += " ORDER BY p.name LIMIT ? OFFSET ?" + params.extend([per_page, offset]) + + return self._query(sql, tuple(params)) + + def get_part(self, part_id: int) -> Optional[dict]: + """Return a single part with group/category info, or None.""" + return self._query( + """ + SELECT + p.id, + p.oem_part_number, + p.name, + p.name_es, + p.description, + p.description_es, + p.weight_kg, + p.material, + p.is_discontinued, + p.superseded_by_id, + p.group_id, + pg.name AS group_name, + pg.name_es AS group_name_es, + pc.id AS category_id, + pc.name AS category_name, + pc.name_es AS category_name_es + FROM parts p + JOIN part_groups pg ON p.group_id = pg.id + JOIN part_categories pc ON pg.category_id = pc.id + WHERE p.id = ? + """, + (part_id,), + one=True, + ) + + def get_alternatives(self, part_id: int) -> list[dict]: + """Return aftermarket alternatives for an OEM part.""" + return self._query( + """ + SELECT + ap.id, + ap.part_number, + ap.name, + ap.name_es, + m.name AS manufacturer_name, + ap.manufacturer_id, + ap.quality_tier, + ap.price_usd, + ap.warranty_months, + ap.in_stock + FROM aftermarket_parts ap + JOIN manufacturers m ON ap.manufacturer_id = m.id + WHERE ap.oem_part_id = ? + ORDER BY ap.quality_tier DESC, ap.price_usd ASC + """, + (part_id,), + ) + + def get_cross_references(self, part_id: int) -> list[dict]: + """Return cross-reference numbers for a part.""" + return self._query( + """ + SELECT id, cross_reference_number, reference_type, source, notes + FROM part_cross_references + WHERE part_id = ? + ORDER BY reference_type, cross_reference_number + """, + (part_id,), + ) + + def get_vehicles_for_part(self, part_id: int) -> list[dict]: + """Return vehicles that use a specific part.""" + return self._query( + """ + SELECT + b.name AS brand, + m.name AS model, + y.year, + e.name AS engine, + mye.trim_level, + vp.quantity_required, + vp.position, + vp.fitment_notes + FROM vehicle_parts vp + JOIN model_year_engine mye ON vp.model_year_engine_id = mye.id + JOIN models m ON mye.model_id = m.id + JOIN brands b ON m.brand_id = b.id + JOIN years y ON mye.year_id = y.id + JOIN engines e ON mye.engine_id = e.id + WHERE vp.part_id = ? + ORDER BY b.name, m.name, y.year + """, + (part_id,), + ) + + # ================================================================== + # Search + # ================================================================== + + def search_parts( + self, query: str, page: int = 1, per_page: int = 15 + ) -> list[dict]: + """Full-text search using FTS5, with fallback to LIKE.""" + per_page = min(per_page, 100) + offset = (page - 1) * per_page + + conn = self._connect() + try: + cursor = conn.cursor() + + # Check if FTS5 table exists + cursor.execute( + "SELECT name FROM sqlite_master " + "WHERE type='table' AND name='parts_fts'" + ) + fts_exists = cursor.fetchone() is not None + + if fts_exists: + # Escape FTS5 special chars by quoting each term + terms = query.split() + quoted = ['"' + t.replace('"', '""') + '"' for t in terms] + fts_query = " ".join(quoted) + + cursor.execute( + """ + SELECT + p.id, + p.oem_part_number, + p.name, + p.name_es, + p.description, + pg.name AS group_name, + pc.name AS category_name, + bm25(parts_fts) AS rank + FROM parts_fts + JOIN parts p ON parts_fts.rowid = p.id + JOIN part_groups pg ON p.group_id = pg.id + JOIN part_categories pc ON pg.category_id = pc.id + WHERE parts_fts MATCH ? + ORDER BY rank + LIMIT ? OFFSET ? + """, + (fts_query, per_page, offset), + ) + else: + search_term = f"%{query}%" + cursor.execute( + """ + SELECT + p.id, + p.oem_part_number, + p.name, + p.name_es, + p.description, + pg.name AS group_name, + pc.name AS category_name, + 0 AS rank + FROM parts p + JOIN part_groups pg ON p.group_id = pg.id + JOIN part_categories pc ON pg.category_id = pc.id + WHERE p.name LIKE ? OR p.name_es LIKE ? + OR p.oem_part_number LIKE ? OR p.description LIKE ? + ORDER BY p.name + LIMIT ? OFFSET ? + """, + ( + search_term, + search_term, + search_term, + search_term, + per_page, + offset, + ), + ) + + return [dict(r) for r in cursor.fetchall()] + finally: + conn.close() + + def search_part_number(self, number: str) -> list[dict]: + """Search OEM, aftermarket, and cross-reference part numbers.""" + search_term = f"%{number}%" + results: list[dict] = [] + + conn = self._connect() + try: + cursor = conn.cursor() + + # OEM parts + cursor.execute( + """ + SELECT id, oem_part_number, name, name_es + FROM parts + WHERE oem_part_number LIKE ? + """, + (search_term,), + ) + for row in cursor.fetchall(): + results.append( + { + **dict(row), + "match_type": "oem", + "matched_number": row["oem_part_number"], + } + ) + + # Aftermarket parts + cursor.execute( + """ + SELECT p.id, p.oem_part_number, p.name, p.name_es, ap.part_number + FROM aftermarket_parts ap + JOIN parts p ON ap.oem_part_id = p.id + WHERE ap.part_number LIKE ? + """, + (search_term,), + ) + for row in cursor.fetchall(): + results.append( + { + "id": row["id"], + "oem_part_number": row["oem_part_number"], + "name": row["name"], + "name_es": row["name_es"], + "match_type": "aftermarket", + "matched_number": row["part_number"], + } + ) + + # Cross-references + cursor.execute( + """ + SELECT p.id, p.oem_part_number, p.name, p.name_es, + pcr.cross_reference_number + FROM part_cross_references pcr + JOIN parts p ON pcr.part_id = p.id + WHERE pcr.cross_reference_number LIKE ? + """, + (search_term,), + ) + for row in cursor.fetchall(): + results.append( + { + "id": row["id"], + "oem_part_number": row["oem_part_number"], + "name": row["name"], + "name_es": row["name_es"], + "match_type": "cross_reference", + "matched_number": row["cross_reference_number"], + } + ) + + return results + finally: + conn.close() + + # ================================================================== + # VIN cache + # ================================================================== + + def get_vin_cache(self, vin: str) -> Optional[dict]: + """Return cached VIN decode data if still valid, else None.""" + return self._query( + """ + SELECT + vin, decoded_data, make, model, year, + engine_info, body_class, drive_type, + model_year_engine_id, created_at, expires_at + FROM vin_cache + WHERE vin = ? AND expires_at > datetime('now') + """, + (vin.upper().strip(),), + one=True, + ) + + def save_vin_cache( + self, + vin: str, + data: str, + make: str, + model: str, + year: int, + engine_info: str, + body_class: str, + drive_type: str, + ) -> int: + """Insert or replace a VIN cache entry (30-day expiry).""" + expires = datetime.utcnow() + timedelta(days=30) + conn = self._connect() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO vin_cache + (vin, decoded_data, make, model, year, + engine_info, body_class, drive_type, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + vin.upper().strip(), + data, + make, + model, + year, + engine_info, + body_class, + drive_type, + expires.isoformat(), + ), + ) + conn.commit() + return cursor.lastrowid + finally: + conn.close() + + # ================================================================== + # Stats + # ================================================================== + + def get_stats(self) -> dict: + """Return counts for all major tables plus top brands by fitment.""" + conn = self._connect() + try: + cursor = conn.cursor() + stats: dict = {} + + for table in [ + "brands", + "models", + "years", + "engines", + "part_categories", + "part_groups", + "parts", + "aftermarket_parts", + "manufacturers", + "vehicle_parts", + "part_cross_references", + ]: + cursor.execute(f"SELECT COUNT(*) FROM {table}") + stats[table] = cursor.fetchone()[0] + + # Top brands by number of fitments + cursor.execute( + """ + SELECT b.name, COUNT(DISTINCT vp.id) AS cnt + FROM brands b + JOIN models m ON m.brand_id = b.id + JOIN model_year_engine mye ON mye.model_id = m.id + JOIN vehicle_parts vp ON vp.model_year_engine_id = mye.id + GROUP BY b.name + ORDER BY cnt DESC + LIMIT 10 + """ + ) + stats["top_brands"] = [ + {"name": r["name"], "count": r["cnt"]} for r in cursor.fetchall() + ] + + return stats + finally: + conn.close() + + # ================================================================== + # Admin — Manufacturers + # ================================================================== + + def get_manufacturers(self) -> list[dict]: + """Return all manufacturers ordered by name.""" + return self._query( + """ + SELECT id, name, type, quality_tier, country, logo_url, website + FROM manufacturers + ORDER BY name + """ + ) + + def create_manufacturer(self, data: dict) -> int: + """Insert a new manufacturer and return its id.""" + return self._execute( + """ + INSERT INTO manufacturers (name, type, quality_tier, country, logo_url, website) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + data["name"], + data.get("type"), + data.get("quality_tier"), + data.get("country"), + data.get("logo_url"), + data.get("website"), + ), + ) + + def update_manufacturer(self, mfr_id: int, data: dict) -> None: + """Update an existing manufacturer.""" + self._execute( + """ + UPDATE manufacturers + SET name = ?, type = ?, quality_tier = ?, country = ?, logo_url = ?, website = ? + WHERE id = ? + """, + ( + data["name"], + data.get("type"), + data.get("quality_tier"), + data.get("country"), + data.get("logo_url"), + data.get("website"), + mfr_id, + ), + ) + + def delete_manufacturer(self, mfr_id: int) -> None: + """Delete a manufacturer by id.""" + self._execute("DELETE FROM manufacturers WHERE id = ?", (mfr_id,)) + + # ================================================================== + # Admin — Parts + # ================================================================== + + def create_part(self, data: dict) -> int: + """Insert a new part and return its id.""" + return self._execute( + """ + INSERT INTO parts + (oem_part_number, name, name_es, group_id, + description, description_es, weight_kg, material) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + data["oem_part_number"], + data["name"], + data.get("name_es"), + data.get("group_id"), + data.get("description"), + data.get("description_es"), + data.get("weight_kg"), + data.get("material"), + ), + ) + + def update_part(self, part_id: int, data: dict) -> None: + """Update an existing part.""" + self._execute( + """ + UPDATE parts + SET oem_part_number = ?, name = ?, name_es = ?, group_id = ?, + description = ?, description_es = ?, weight_kg = ?, material = ? + WHERE id = ? + """, + ( + data["oem_part_number"], + data["name"], + data.get("name_es"), + data.get("group_id"), + data.get("description"), + data.get("description_es"), + data.get("weight_kg"), + data.get("material"), + part_id, + ), + ) + + def delete_part(self, part_id: int) -> None: + """Delete a part by id.""" + self._execute("DELETE FROM parts WHERE id = ?", (part_id,)) + + # ================================================================== + # Admin — Cross-references + # ================================================================== + + def create_crossref(self, data: dict) -> int: + """Insert a new cross-reference and return its id.""" + return self._execute( + """ + INSERT INTO part_cross_references + (part_id, cross_reference_number, reference_type, source, notes) + VALUES (?, ?, ?, ?, ?) + """, + ( + data["part_id"], + data["cross_reference_number"], + data.get("reference_type"), + data.get("source"), + data.get("notes"), + ), + ) + + def update_crossref(self, xref_id: int, data: dict) -> None: + """Update an existing cross-reference.""" + self._execute( + """ + UPDATE part_cross_references + SET part_id = ?, cross_reference_number = ?, + reference_type = ?, source = ?, notes = ? + WHERE id = ? + """, + ( + data["part_id"], + data["cross_reference_number"], + data.get("reference_type"), + data.get("source"), + data.get("notes"), + xref_id, + ), + ) + + def delete_crossref(self, xref_id: int) -> None: + """Delete a cross-reference by id.""" + self._execute( + "DELETE FROM part_cross_references WHERE id = ?", (xref_id,) + ) + + def get_crossrefs_paginated( + self, page: int = 1, per_page: int = 15 + ) -> list[dict]: + """Return paginated cross-references with part info.""" + per_page = min(per_page, 100) + offset = (page - 1) * per_page + return self._query( + """ + SELECT + pcr.id, + pcr.part_id, + pcr.cross_reference_number, + pcr.reference_type, + pcr.source, + pcr.notes, + p.oem_part_number, + p.name AS part_name + FROM part_cross_references pcr + JOIN parts p ON pcr.part_id = p.id + ORDER BY pcr.id + LIMIT ? OFFSET ? + """, + (per_page, offset), + ) diff --git a/console/tests/test_db.py b/console/tests/test_db.py new file mode 100644 index 0000000..acff599 --- /dev/null +++ b/console/tests/test_db.py @@ -0,0 +1,273 @@ +""" +Tests for the Database abstraction layer. + +All tests run against the real SQLite database at vehicle_database/vehicle_database.db. +""" + +import pytest + +from console.db import Database + + +@pytest.fixture(scope="module") +def db(): + """Provide a shared Database instance for all tests in this module.""" + return Database() + + +# ========================================================================= +# Vehicle navigation +# ========================================================================= + +class TestGetBrands: + def test_returns_nonempty_list(self, db): + brands = db.get_brands() + assert isinstance(brands, list) + assert len(brands) > 0 + + def test_each_brand_has_name_key(self, db): + brands = db.get_brands() + for b in brands: + assert "name" in b + + def test_each_brand_has_id_and_country(self, db): + brands = db.get_brands() + for b in brands: + assert "id" in b + assert "country" in b + + +class TestGetModels: + def test_no_filter_returns_nonempty(self, db): + models = db.get_models() + assert isinstance(models, list) + assert len(models) > 0 + + def test_filter_by_uppercase_brand(self, db): + models = db.get_models(brand="TOYOTA") + assert isinstance(models, list) + assert len(models) > 0 + + def test_filter_by_lowercase_brand(self, db): + """Brand filtering must be case-insensitive.""" + models = db.get_models(brand="toyota") + assert isinstance(models, list) + assert len(models) > 0 + + def test_each_model_has_id_and_name(self, db): + models = db.get_models() + for m in models[:5]: + assert "id" in m + assert "name" in m + + +class TestGetYears: + def test_returns_list(self, db): + years = db.get_years() + assert isinstance(years, list) + assert len(years) > 0 + + def test_filter_by_brand(self, db): + years = db.get_years(brand="TOYOTA") + assert isinstance(years, list) + assert len(years) > 0 + + def test_each_year_has_id_and_year(self, db): + years = db.get_years() + for y in years[:5]: + assert "id" in y + assert "year" in y + + +class TestGetEngines: + def test_returns_list(self, db): + engines = db.get_engines() + assert isinstance(engines, list) + assert len(engines) > 0 + + def test_filter_by_brand(self, db): + engines = db.get_engines(brand="TOYOTA") + assert isinstance(engines, list) + assert len(engines) > 0 + + +class TestGetModelYearEngine: + def test_returns_list(self, db): + result = db.get_model_year_engine( + brand="TOYOTA", model="Corolla", year=2020, engine_id=None + ) + assert isinstance(result, list) + + +# ========================================================================= +# Parts catalog +# ========================================================================= + +class TestGetCategories: + def test_returns_exactly_12(self, db): + categories = db.get_categories() + assert isinstance(categories, list) + assert len(categories) == 12 + + def test_each_has_expected_keys(self, db): + categories = db.get_categories() + for c in categories: + assert "id" in c + assert "name" in c + + +class TestGetGroups: + def test_returns_nonempty_for_known_category(self, db): + groups = db.get_groups(category_id=2) + assert isinstance(groups, list) + assert len(groups) > 0 + + def test_each_group_has_name(self, db): + groups = db.get_groups(category_id=2) + for g in groups: + assert "name" in g + + +class TestGetParts: + def test_returns_list(self, db): + parts = db.get_parts() + assert isinstance(parts, list) + assert len(parts) > 0 + + def test_pagination(self, db): + page1 = db.get_parts(page=1, per_page=5) + page2 = db.get_parts(page=2, per_page=5) + assert len(page1) <= 5 + assert len(page2) <= 5 + # Pages should contain different items (if enough data) + if page1 and page2: + ids1 = {p["id"] for p in page1} + ids2 = {p["id"] for p in page2} + assert ids1.isdisjoint(ids2) + + +class TestGetPart: + def test_returns_dict_with_oem_part_number(self, db): + part = db.get_part(1) + assert isinstance(part, dict) + assert "oem_part_number" in part + + def test_includes_group_and_category_info(self, db): + part = db.get_part(1) + assert "group_name" in part + assert "category_name" in part + + def test_nonexistent_returns_none(self, db): + part = db.get_part(999999) + assert part is None + + +class TestGetAlternatives: + def test_returns_list(self, db): + alts = db.get_alternatives(1) + assert isinstance(alts, list) + + +class TestGetCrossReferences: + def test_returns_list(self, db): + refs = db.get_cross_references(1) + assert isinstance(refs, list) + + +class TestGetVehiclesForPart: + def test_returns_list(self, db): + vehicles = db.get_vehicles_for_part(1) + assert isinstance(vehicles, list) + assert len(vehicles) > 0 + + +# ========================================================================= +# Search +# ========================================================================= + +class TestSearchParts: + def test_returns_results_for_brake(self, db): + results = db.search_parts("brake") + assert isinstance(results, list) + assert len(results) > 0 + + def test_each_result_has_expected_keys(self, db): + results = db.search_parts("brake") + for r in results[:3]: + assert "id" in r + assert "name" in r + assert "oem_part_number" in r + + +class TestSearchPartNumber: + def test_returns_results_for_04465(self, db): + results = db.search_part_number("04465") + assert isinstance(results, list) + assert len(results) > 0 + + def test_each_result_has_match_type(self, db): + results = db.search_part_number("04465") + for r in results: + assert "match_type" in r + + +# ========================================================================= +# VIN cache +# ========================================================================= + +class TestVinCache: + def test_get_nonexistent_vin_returns_none(self, db): + result = db.get_vin_cache("00000000000000000") + assert result is None + + +# ========================================================================= +# Stats +# ========================================================================= + +class TestGetStats: + def test_returns_dict_with_required_keys(self, db): + stats = db.get_stats() + assert isinstance(stats, dict) + assert "brands" in stats + assert "models" in stats + assert "parts" in stats + + def test_counts_are_positive(self, db): + stats = db.get_stats() + assert stats["brands"] > 0 + assert stats["models"] > 0 + assert stats["parts"] > 0 + + def test_includes_top_brands(self, db): + stats = db.get_stats() + assert "top_brands" in stats + assert isinstance(stats["top_brands"], list) + + +# ========================================================================= +# Manufacturers +# ========================================================================= + +class TestGetManufacturers: + def test_returns_nonempty_list(self, db): + manufacturers = db.get_manufacturers() + assert isinstance(manufacturers, list) + assert len(manufacturers) > 0 + + def test_each_has_name(self, db): + manufacturers = db.get_manufacturers() + for m in manufacturers: + assert "name" in m + assert "id" in m + + +# ========================================================================= +# Admin CRUD — smoke tests +# ========================================================================= + +class TestCrossrefsPaginated: + def test_returns_list(self, db): + refs = db.get_crossrefs_paginated(page=1, per_page=5) + assert isinstance(refs, list) + assert len(refs) <= 5