You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
import sqlite3
from contextlib import contextmanager
from typing import Iterator, List, Optional

import httpx
from pydantic import BaseModel

from app import config
|
|
|
|
# Module-wide async HTTP client: redirects are followed and requests use a
# 20-second timeout.
# NOTE(review): the client is created at import time and is not closed or used
# anywhere in this section — confirm it is needed/used elsewhere.
_http = httpx.AsyncClient(follow_redirects=True, timeout=20)
|
|
|
|
|
|
class Location(BaseModel):
|
|
id: int
|
|
country: str
|
|
city: str
|
|
region: str = None
|
|
|
|
|
|
@contextmanager
def get_connection(db_path: Optional[str] = None) -> Iterator[sqlite3.Connection]:
    """Open a SQLite connection whose rows are returned as plain dicts.

    Usage::

        with get_connection() as conn:
            conn.execute(...)

    Args:
        db_path: Database to open. Defaults to ``config.DB_PATH``; pass an
            explicit path (for example ``":memory:"``) to override it.

    Yields:
        A ``sqlite3.Connection`` whose ``row_factory`` turns every row into a
        ``{column_name: value}`` dict.

    The connection is closed when the ``with`` block exits, including when the
    block raises. ``close()`` does not commit, so callers must commit their own
    writes.
    """

    def dict_factory(cursor, row):
        # cursor.description holds one entry per column; entry[0] is its name.
        return {col[0]: value for col, value in zip(cursor.description, row)}

    # check_same_thread=False lets threads other than the creating one use
    # this connection.
    conn = sqlite3.connect(
        config.DB_PATH if db_path is None else db_path,
        check_same_thread=False,
    )
    conn.row_factory = dict_factory
    try:
        yield conn
    finally:
        # Runs on normal exit and when the caller's with-body raises, so the
        # connection is never leaked.
        conn.close()
|
|
|
|
|
|
def _check_table_exists(conn: sqlite3.Connection) -> bool:
|
|
stmt = conn.execute("select count(*) as count from sqlite_master where type=? and name=?", ('table', 'locations'))
|
|
return bool(stmt.fetchone()['count'])
|
|
|
|
|
|
def init_db(conn: sqlite3.Connection) -> bool:
    """Create the ``locations`` table if it is not there yet.

    Returns:
        True when the ``locations`` table exists after the call.

    NOTE(review): only ``locations`` is created here; the ``locations_search``
    table queried by ``find_location_by_name`` is not — presumably it is set
    up elsewhere, confirm.
    """
    ddl = '''
        CREATE TABLE IF NOT EXISTS locations (
        id INTEGER UNIQUE NOT NULL PRIMARY KEY,
        country TEXT NOT NULL,
        city TEXT NOT NULL,
        region TEXT
        )
        '''
    table_missing = not _check_table_exists(conn)
    if table_missing:
        conn.execute(ddl)
    return _check_table_exists(conn)
|
|
|
|
|
|
def save_locations(conn: sqlite3.Connection, locations: List[dict]):
    """Insert location records into the ``locations`` table and commit them.

    Args:
        conn: Open connection; the ``locations`` table must already exist.
        locations: Mappings with ``id``, ``country`` and ``city`` keys and an
            optional ``region`` key (stored as NULL when absent or None).
            Extra keys are ignored by sqlite3's named-parameter binding.

    Raises:
        sqlite3.IntegrityError: e.g. on a duplicate ``id``. The current
            transaction is rolled back in that case (including any other
            uncommitted changes on ``conn``), so no partial batch is left
            pending.
    """
    sql = 'INSERT INTO locations (id, country, city, region) VALUES (:id, :country, :city, :region)'
    # ``region`` is nullable in the schema, so supply a default instead of
    # letting sqlite3 reject the row for a missing ``:region`` binding.
    rows = ({'region': None, **loc} for loc in locations)
    # ``with conn`` commits on success and rolls back on any exception. A bare
    # executemany + commit would leave the already-inserted rows of a failed
    # batch in the open transaction, where a later commit could persist them.
    with conn:
        conn.executemany(sql, rows)
|
|
|
|
|
|
def find_locations(conn: sqlite3.Connection,
                   *,
                   country: str,
                   city: str = None) -> List[Location]:
    """Return the locations of ``country``, optionally narrowed to one ``city``.

    Rows are ordered by city, then region. A falsy ``city`` (None or "")
    applies no city filter. An empty list is returned when nothing matches.
    """
    template = '''
        SELECT DISTINCT id, country, city, region
        FROM locations
        {where}
        ORDER BY city, region
        '''
    # Each filter pairs a SQL predicate with the value bound to its placeholder.
    filters = [('country = ?', country)]
    if city:
        filters.append(('city = ?', city))
    clauses = [clause for clause, _ in filters]
    params = [value for _, value in filters]
    query = template.format(where=f"WHERE {' and '.join(clauses)}")
    found = conn.execute(query, params).fetchall()
    return [Location(**record) for record in found]
|
|
|
|
|
|
def find_countries(conn: sqlite3.Connection) -> List[str]:
    """Return each distinct ``country`` in ``locations``, ordered by country.

    Rows are read by key, so the connection's row factory must return mappings.
    """
    records = conn.execute('select DISTINCT country from locations ORDER BY country').fetchall()
    return [record['country'] for record in records]
|
|
|
|
|
|
def find_cities(conn: sqlite3.Connection, country: str) -> List[str]:
    """Return each distinct ``city`` recorded for ``country``, ordered by city.

    Rows are read by key, so the connection's row factory must return mappings.
    """
    query = 'select DISTINCT city from locations WHERE country = ? ORDER BY city'
    cursor = conn.execute(query, [country])
    return [entry['city'] for entry in cursor.fetchall()]
|
|
|
|
|
|
def find_location_by_name(conn: sqlite3.Connection, q: str) -> List[Location]:
    """Search ``locations_search`` with ``MATCH`` and return the hits.

    ``q`` is bound as the right-hand side of ``MATCH``. Results are ordered by
    country, city, region; an empty list is returned when nothing matches.

    NOTE(review): ``locations_search`` is presumably an FTS virtual table
    (it is not created in this file's ``init_db``) — confirm. If so, ``q`` is
    parsed as full-text query syntax, and raw user input containing quotes or
    operators may raise ``sqlite3.OperationalError``.
    """
    query = '''
        SELECT DISTINCT id, country, city, region
        FROM locations_search
        WHERE locations_search match :q
        ORDER BY country, city, region
        '''
    bindings = {'q': f'{q}'}
    matches = conn.execute(query, bindings).fetchall()
    return [Location(**match) for match in matches]
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: run a name search for 'ala' against the configured
    # database and pretty-print the resulting Location objects.
    import pprint

    with get_connection() as db:
        pprint.pprint(find_location_by_name(db, 'ala'))