"""SQLite-backed storage and lookup helpers for ``locations`` records."""
import sqlite3
from contextlib import contextmanager
from typing import Iterable, Iterator, List, Optional

import httpx
from pydantic import BaseModel

from app import config

# NOTE(review): not referenced in this module; presumably used by importers
# of this module — confirm before removing.
_http = httpx.AsyncClient()


class Location(BaseModel):
    """One row of the ``locations`` table (or of the ``locations_search`` index)."""

    id: int
    country: str
    city: str
    # The ``region`` column is nullable. ``Optional`` lets an explicit NULL
    # (``region=None`` from a DB row) validate under pydantic v2 as well as v1.
    region: Optional[str] = None


def _dict_factory(cursor: sqlite3.Cursor, row: tuple) -> dict:
    """Map a result row to a ``{column_name: value}`` dict."""
    return {column[0]: value for column, value in zip(cursor.description, row)}


@contextmanager
def get_connection() -> Iterator[sqlite3.Connection]:
    """Open a connection to ``config.DB_PATH`` that returns rows as dicts.

    The connection is closed when the ``with`` block exits, including when
    it exits because of an exception.
    """
    # check_same_thread=False allows the connection to be used from threads
    # other than the one that created it.
    conn = sqlite3.connect(config.DB_PATH, check_same_thread=False)
    conn.row_factory = _dict_factory
    try:
        yield conn
    finally:
        conn.close()


def _check_table_exists(conn: sqlite3.Connection) -> bool:
    """Return True if a table named ``locations`` exists in the schema."""
    stmt = conn.execute(
        "select count(*) as count from sqlite_master where type=? and name=?",
        ('table', 'locations'),
    )
    return bool(stmt.fetchone()['count'])


def init_db(conn: sqlite3.Connection) -> bool:
    """Create the ``locations`` table if it is missing.

    Returns:
        True if the ``locations`` table exists afterwards.
    """
    conn.execute('''
        CREATE TABLE IF NOT EXISTS locations (
            id INTEGER UNIQUE NOT NULL PRIMARY KEY,
            country TEXT NOT NULL,
            city TEXT NOT NULL,
            region TEXT
        )
    ''')
    return _check_table_exists(conn)


def save_locations(conn: sqlite3.Connection, locations: Iterable[dict]) -> None:
    """Insert ``locations`` into the ``locations`` table and commit.

    Each dict must provide ``id``, ``country`` and ``city``. ``region`` may be
    omitted, in which case NULL is stored.

    Raises:
        sqlite3.IntegrityError: e.g. on a duplicate ``id``.
    """
    sql = 'INSERT INTO locations (id, country, city, region) VALUES (:id, :country, :city, :region)'
    # Named-parameter binding fails when a key is absent, so default the
    # nullable ``region`` column instead of requiring every caller to set it.
    rows = ({'region': None, **location} for location in locations)
    conn.executemany(sql, rows)
    conn.commit()


def find_locations(conn: sqlite3.Connection, *, country: str,
                   city: Optional[str] = None) -> List[Location]:
    """Return locations in ``country``, optionally narrowed to ``city``.

    A falsy ``city`` (``None`` or ``''``) applies no city filter. Results are
    ordered by city, then region.
    """
    conditions = ['country = ?']
    values = [country]
    if city:
        conditions.append('city = ?')
        values.append(city)
    # Only fixed condition strings are interpolated; user values are bound
    # as parameters.
    sql = (
        'SELECT DISTINCT id, country, city, region FROM locations '
        f"WHERE {' and '.join(conditions)} "
        'ORDER BY city, region'
    )
    return [Location(**row) for row in conn.execute(sql, values).fetchall()]


def find_countries(conn: sqlite3.Connection) -> List[str]:
    """Return the distinct country names, sorted."""
    sql = 'select DISTINCT country from locations ORDER BY country'
    return [row['country'] for row in conn.execute(sql).fetchall()]


def find_cities(conn: sqlite3.Connection, country: str) -> List[str]:
    """Return the distinct city names in ``country``, sorted."""
    sql = 'select DISTINCT city from locations WHERE country = ? ORDER BY city'
    return [row['city'] for row in conn.execute(sql, [country]).fetchall()]


def find_location_by_name(conn: sqlite3.Connection, q: str) -> List[Location]:
    """Search ``locations_search`` with ``MATCH q``.

    Results are ordered by country, city and region.
    """
    # NOTE(review): ``locations_search`` is not created by init_db here; it is
    # presumably a full-text (FTS) table built elsewhere — confirm.
    # NOTE(review): ``q`` is interpreted as MATCH query syntax, so untrusted
    # input containing operators or quotes may raise sqlite3.OperationalError.
    sql = '''
        SELECT DISTINCT id, country, city, region
        FROM locations_search
        WHERE locations_search match :q
        ORDER BY country, city, region
    '''
    rows = conn.execute(sql, {'q': str(q)}).fetchall()
    return [Location(**row) for row in rows]


if __name__ == '__main__':
    from pprint import pprint

    with get_connection() as conn:
        pprint(find_location_by_name(conn, 'ala'))