You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

114 lines
3.1 KiB
Python

import re
import sqlite3
from contextlib import contextmanager
from typing import Iterator, List, Optional

import httpx
from pydantic import BaseModel

from app import config
# Shared async HTTP client created at import time.
# NOTE(review): not referenced anywhere in this module chunk and never closed
# here; confirm other modules use it and that something closes it on shutdown.
_http: httpx.AsyncClient = httpx.AsyncClient()
class Location(BaseModel):
    """One row of the ``locations`` table.

    ``region`` corresponds to the nullable ``region TEXT`` column, so it must
    accept ``None`` as well as a string.
    """

    id: int
    country: str
    city: str
    # Explicit Optional rather than ``str = None``: rows with a NULL region are
    # unpacked into this model as region=None, which a plain ``str`` field
    # rejects under pydantic v2 (only v1 made such fields implicitly Optional).
    region: Optional[str] = None
@contextmanager
def get_connection(db_path: Optional[str] = None) -> Iterator[sqlite3.Connection]:
    """Open a SQLite connection that returns each row as a plain ``dict``.

    Use as a context manager::

        with get_connection() as conn:
            ...

    Args:
        db_path: database to open. Defaults to ``config.DB_PATH`` when omitted;
            passing e.g. ``":memory:"`` is handy for tests.

    Yields:
        The open connection. It is closed when the ``with`` block exits,
        including when it exits because of an exception.
    """

    def dict_factory(cursor: sqlite3.Cursor, row: tuple) -> dict:
        # Pair every column name from the cursor description with its value.
        return {column[0]: value for column, value in zip(cursor.description, row)}

    conn = sqlite3.connect(config.DB_PATH if db_path is None else db_path)
    try:
        conn.row_factory = dict_factory
        yield conn
    finally:
        # Without try/finally, an exception raised inside the caller's
        # ``with`` body would skip close() and leak the connection.
        conn.close()
def _check_table_exists(conn: sqlite3.Connection) -> bool:
stmt = conn.execute("select count(*) as count from sqlite_master where type=? and name=?", ('table', 'locations'))
return bool(stmt.fetchone()['count'])
def init_db(conn: sqlite3.Connection) -> bool:
    """Create the ``locations`` table when it is not already present.

    Returns:
        True if the ``locations`` table exists once this call finishes.
    """
    table_missing = not _check_table_exists(conn)
    if table_missing:
        ddl = '''
CREATE TABLE IF NOT EXISTS locations (
id INTEGER UNIQUE NOT NULL PRIMARY KEY,
country TEXT NOT NULL,
city TEXT NOT NULL,
region TEXT
)
'''
        conn.execute(ddl)
    # Re-query the schema so the result reflects the actual post-call state.
    return _check_table_exists(conn)
def save_locations(conn: sqlite3.Connection, locations: List[dict]):
    """Insert ``locations`` into the ``locations`` table as a single transaction.

    Args:
        conn: open connection to the database.
        locations: mappings with ``id``, ``country`` and ``city`` keys and an
            optional ``region`` key, which is stored as NULL when absent.

    Raises:
        sqlite3.Error: if any insert fails (e.g. ``IntegrityError`` on a
            duplicate ``id``). The transaction is rolled back, so no rows
            from this call remain; any other uncommitted changes pending on
            ``conn`` are rolled back as well.
    """
    sql = 'INSERT INTO locations (id, country, city, region) VALUES (:id, :country, :city, :region)'
    # region is nullable in the schema and optional on Location; default it so
    # a missing key binds NULL instead of raising sqlite3.ProgrammingError.
    rows = ({'region': None, **location} for location in locations)
    # The connection context manager commits on success and rolls back on
    # error, so a failed batch does not leave half-inserted rows pending.
    with conn:
        conn.executemany(sql, rows)
def find_locations(conn: sqlite3.Connection,
                   *,
                   country: str,
                   city: str = '') -> List[Location]:
    """Return the locations in ``country``, optionally restricted to ``city``.

    An empty ``city`` means "any city". Results are ordered by city, then
    region. Rows must be mappings, because each one is unpacked with ``**``
    into a ``Location``.
    """
    template = '''
SELECT DISTINCT id, country, city, region
FROM locations
{where}
ORDER BY city, region
'''
    # Each filter pairs a fixed SQL fragment with its bound value; only the
    # fragments are formatted into the query, values stay parameterized.
    filters = [('country = ?', country)]
    if city:
        filters.append(('city = ?', city))
    clauses = [clause for clause, _ in filters]
    params = [value for _, value in filters]
    query = template.format(where=f"WHERE {' and '.join(clauses)}")
    found = conn.execute(query, params).fetchall()
    return [Location(**record) for record in found]
def find_countries(conn: sqlite3.Connection) -> List[str]:
    """List each distinct country in ``locations``, ordered by country.

    ``conn`` must produce rows that can be indexed by column name.
    """
    query = 'select DISTINCT country from locations ORDER BY country'
    return [record['country'] for record in conn.execute(query)]
def find_cities(conn: sqlite3.Connection, country: str) -> List[str]:
    """List each distinct city recorded for ``country``, ordered by city.

    ``conn`` must produce rows that can be indexed by column name.
    """
    query = 'select DISTINCT city from locations WHERE country = ? ORDER BY city'
    return [record['city'] for record in conn.execute(query, (country,))]
def find_location_by_name(conn: sqlite3.Connection, q: str) -> List[Location]:
    """Full-text search ``locations_search`` for locations matching ``q``.

    Each word of ``q`` becomes a prefix term, so ``"ala"`` matches tokens that
    start with "ala"; when ``q`` has several words, all of them must match.

    Args:
        conn: connection whose rows are mappings (they are unpacked with
            ``**`` into ``Location``).
        q: free-text query; punctuation and underscores are ignored.

    Returns:
        Matching locations ordered by country, city, region. An empty list
        when ``q`` contains no searchable words or nothing matches.
    """
    # '%' is LIKE syntax, not an FTS wildcard: FTS5 rejects '%q%' as a syntax
    # error, and older FTS versions do not treat '%' as a wildcard either.
    # Build ``word*`` prefix terms instead. Words are lowercased so user input
    # such as "OR"/"AND"/"NOT" is not parsed as an FTS operator.
    # NOTE(review): assumes the locations_search tokenizer folds case (the
    # built-in tokenizers do by default) - confirm against its definition.
    words = re.findall(r'[^\W_]+', q.lower())
    if not words:
        return []
    match_expr = ' '.join(f'{word}*' for word in words)
    sql = '''
SELECT DISTINCT id, country, city, region
FROM locations_search
WHERE locations_search match :q
ORDER BY country, city, region
'''
    rows = conn.execute(sql, {'q': match_expr}).fetchall()
    return [Location(**row) for row in rows]
if __name__ == '__main__':
    # Ad-hoc smoke test: run a name search for 'ala' against the configured DB.
    from pprint import pprint

    with get_connection() as connection:
        results = find_location_by_name(connection, 'ala')
        pprint(results)