Add types and docstrings + upgrade to SQLAlchmey 2

This commit is contained in:
Edward Betts 2023-11-01 20:54:19 +00:00
parent 82671959bb
commit 5e8d1a99b0
8 changed files with 248 additions and 125 deletions

View File

@ -8,6 +8,8 @@ import flask
import geoalchemy2 import geoalchemy2
import sqlalchemy import sqlalchemy
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped
from sqlalchemy.sql import select from sqlalchemy.sql import select
from matcher import database, model, wikidata, wikidata_api from matcher import database, model, wikidata, wikidata_api
@ -91,16 +93,22 @@ def make_envelope(bounds: list[float]) -> geoalchemy2.functions.ST_MakeEnvelope:
return sqlalchemy.func.ST_MakeEnvelope(*bounds, srid) return sqlalchemy.func.ST_MakeEnvelope(*bounds, srid)
def parse_point(point: str) -> tuple[str, str]:
"""Parse point from PostGIS."""
m = re_point.match(point)
assert m
lon, lat = m.groups()
assert lon and lat
return (lon, lat)
def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]: def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]:
"""Get centroid of bounding box.""" """Get centroid of bounding box."""
bbox = make_envelope(bbox) bbox = make_envelope(bbox)
centroid = database.session.query( centroid = database.session.query(
sqlalchemy.func.ST_AsText(sqlalchemy.func.ST_Centroid(bbox)) sqlalchemy.func.ST_AsText(sqlalchemy.func.ST_Centroid(bbox))
).scalar() ).scalar()
m = re_point.match(centroid) lon, lat = parse_point(centroid)
assert m
lon, lat = m.groups()
assert lon and lat
return (lat, lon) return (lat, lon)
@ -117,26 +125,17 @@ def make_envelope_around_point(
s = select( s = select(
[ [
sqlalchemy.func.ST_AsText( sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(0)) sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(deg))
), )
sqlalchemy.func.ST_AsText( for deg in (0, 90, 180, 270)
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(90))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(180))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(270))
),
] ]
) )
row = conn.execute(s).fetchone() coords = [parse_point(i) for i in conn.execute(s).fetchone()]
coords = [[float(v) for v in re_point.match(i).groups()] for i in row]
north = coords[0][1] north = float(coords[0][1])
east = coords[1][0] east = float(coords[1][0])
south = coords[2][1] south = float(coords[2][1])
west = coords[3][0] west = float(coords[3][0])
return sqlalchemy.func.ST_MakeEnvelope(west, south, east, north, srid) return sqlalchemy.func.ST_MakeEnvelope(west, south, east, north, srid)
@ -148,10 +147,15 @@ def drop_way_area(tags: TagsType) -> TagsType:
return tags return tags
def get_part_of(table_name, src_id, bbox): def get_part_of(
table_name: str, src_id: int, bbox: geoalchemy2.functions.ST_MakeEnvelope
) -> list[dict[str, typing.Any]]:
"""Get part of."""
table_map = {"point": point, "line": line, "polygon": polygon} table_map = {"point": point, "line": line, "polygon": polygon}
table_alias = table_map[table_name].alias() table_alias = table_map[table_name].alias()
tags: Mapped[postgresql.HSTORE] = polygon.c.tags
s = ( s = (
select( select(
[ [
@ -165,11 +169,8 @@ def get_part_of(table_name, src_id, bbox):
sqlalchemy.func.ST_Intersects(bbox, polygon.c.way), sqlalchemy.func.ST_Intersects(bbox, polygon.c.way),
sqlalchemy.func.ST_Covers(polygon.c.way, table_alias.c.way), sqlalchemy.func.ST_Covers(polygon.c.way, table_alias.c.way),
table_alias.c.osm_id == src_id, table_alias.c.osm_id == src_id,
polygon.c.tags.has_key("name"), tags.has_key("name"),
or_( or_(tags.has_key("landuse"), tags.has_key("amenity")),
polygon.c.tags.has_key("landuse"),
polygon.c.tags.has_key("amenity"),
),
) )
) )
.group_by(polygon.c.osm_id, polygon.c.tags) .group_by(polygon.c.osm_id, polygon.c.tags)
@ -228,6 +229,7 @@ def get_isa_count(items: list[model.Item]) -> list[tuple[str, int]]:
if not isa: if not isa:
print("missing IsA:", item.qid) print("missing IsA:", item.qid)
continue continue
assert isinstance(isa, dict) and isinstance(isa["id"], str)
isa_count[isa["id"]] += 1 isa_count[isa["id"]] += 1
return isa_count.most_common() return isa_count.most_common()
@ -920,7 +922,7 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
"geojson": json.loads(geojson), "geojson": json.loads(geojson),
"presets": get_presets_from_tags(shape, tags), "presets": get_presets_from_tags(shape, tags),
"address_list": address_list, "address_list": address_list,
"centroid": list(reversed(re_point.match(centroid).groups())), "centroid": list(reversed(parse_point(centroid))),
} }
if area is not None: if area is not None:
cur["area"] = area cur["area"] = area
@ -980,23 +982,23 @@ def check_is_street_number_first(latlng):
flask.g.street_number_first = is_street_number_first(*latlng) flask.g.street_number_first = is_street_number_first(*latlng)
class ItemDetailType(typing.TypedDict): class ItemDetailType(typing.TypedDict, total=False):
"""Details of an item as a dict.""" """Details of an item as a dict."""
qid: str qid: str
label: str label: str
description: str description: str | None
markers: list[dict[str, float]] markers: list[dict[str, float]]
image_list: list[str] image_list: list[str]
street_address: list[str] street_address: list[str]
isa_list: list[dict[str, str]] isa_list: list[dict[str, str]]
closed: bool closed: list[str]
inception: str inception: str
p1619: str p1619: str
p576: str p576: str
heritage_designation: str heritage_designation: str
wikipedia: dict[str, str] wikipedia: list[dict[str, str]]
identifiers: list[str] identifiers: dict[str, list[str]]
def item_detail(item: model.Item) -> ItemDetailType: def item_detail(item: model.Item) -> ItemDetailType:
@ -1036,7 +1038,7 @@ def item_detail(item: model.Item) -> ItemDetailType:
if site.endswith("wiki") and len(site) < 8 if site.endswith("wiki") and len(site) < 8
] ]
d = { d: ItemDetailType = {
"qid": item.qid, "qid": item.qid,
"label": item.label(), "label": item.label(),
"description": item.description(), "description": item.description(),

View File

@ -1,4 +1,6 @@
"""Database functions.""" """Database."""
from datetime import datetime
import flask import flask
import sqlalchemy import sqlalchemy
@ -8,7 +10,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker
session: sqlalchemy.orm.scoping.scoped_session = scoped_session(sessionmaker()) session: sqlalchemy.orm.scoping.scoped_session = scoped_session(sessionmaker())
timeout = 20_000 # 20 seconds timeout = 2_000 # 20 seconds
def init_db(db_url: str, echo: bool = False) -> None: def init_db(db_url: str, echo: bool = False) -> None:
@ -17,7 +19,7 @@ def init_db(db_url: str, echo: bool = False) -> None:
def get_engine(db_url: str, echo: bool = False) -> sqlalchemy.engine.base.Engine: def get_engine(db_url: str, echo: bool = False) -> sqlalchemy.engine.base.Engine:
"""Create an engine objcet.""" """Create an engine object."""
return create_engine( return create_engine(
db_url, db_url,
pool_recycle=3600, pool_recycle=3600,
@ -40,9 +42,10 @@ def init_app(app: flask.app.Flask, echo: bool = False) -> None:
session.configure(bind=get_engine(db_url, echo=echo)) session.configure(bind=get_engine(db_url, echo=echo))
@app.teardown_appcontext @app.teardown_appcontext
def shutdown_session(exception: Exception | None = None) -> None: def shutdown_session(exception: BaseException | None = None) -> None:
session.remove() session.remove()
def now_utc(): def now_utc() -> sqlalchemy.sql.functions.Function[datetime]:
"""Now with UTC timezone."""
return func.timezone("utc", func.now()) return func.timezone("utc", func.now())

View File

@ -1,14 +1,17 @@
from flask import g
from . import user_agent_headers, database, osm_oauth, mail
from .model import Changeset
import requests
import html import html
import requests
from flask import g
from . import database, mail, osm_oauth, user_agent_headers
from .model import Changeset
really_save = True really_save = True
osm_api_base = "https://api.openstreetmap.org/api/0.6" osm_api_base = "https://api.openstreetmap.org/api/0.6"
def new_changeset(comment: str) -> str: def new_changeset(comment: str) -> str:
"""XML for a new changeset."""
return f""" return f"""
<osm> <osm>
<changeset> <changeset>
@ -18,11 +21,12 @@ def new_changeset(comment: str) -> str:
</osm>""" </osm>"""
def osm_request(path, **kwargs): def osm_request(path, **kwargs) -> requests.Response:
return osm_oauth.api_put_request(path, **kwargs) return osm_oauth.api_put_request(path, **kwargs)
def create_changeset(changeset): def create_changeset(changeset: str) -> requests.Response:
"""Create new changeset."""
try: try:
return osm_request("/changeset/create", data=changeset.encode("utf-8")) return osm_request("/changeset/create", data=changeset.encode("utf-8"))
except requests.exceptions.HTTPError as r: except requests.exceptions.HTTPError as r:
@ -31,11 +35,15 @@ def create_changeset(changeset):
raise raise
def close_changeset(changeset_id): def close_changeset(changeset_id: int) -> requests.Response:
"""Close changeset."""
return osm_request(f"/changeset/{changeset_id}/close") return osm_request(f"/changeset/{changeset_id}/close")
def save_element(osm_type, osm_id, element_data): def save_element(
osm_type: str, osm_id: int, element_data: str
) -> requests.Response | None:
"""Upload new version of object to OSM map."""
osm_path = f"/{osm_type}/{osm_id}" osm_path = f"/{osm_type}/{osm_id}"
r = osm_request(osm_path, data=element_data) r = osm_request(osm_path, data=element_data)
reply = r.text.strip() reply = r.text.strip()
@ -56,9 +64,12 @@ error:
mail.send_mail(subject, body) mail.send_mail(subject, body)
return None
def record_changeset(**kwargs):
change = Changeset(created=database.now_utc(), **kwargs) def record_changeset(**kwargs: str) -> Changeset:
"""Record changeset in the database."""
change: Changeset = Changeset(created=database.now_utc(), **kwargs)
database.session.add(change) database.session.add(change)
database.session.commit() database.session.commit()
@ -66,6 +77,7 @@ def record_changeset(**kwargs):
return change return change
def get_existing(osm_type, osm_id): def get_existing(osm_type: str, osm_id: int) -> requests.Response:
"""Get existing OSM object."""
url = f"{osm_api_base}/{osm_type}/{osm_id}" url = f"{osm_api_base}/{osm_type}/{osm_id}"
return requests.get(url, headers=user_agent_headers()) return requests.get(url, headers=user_agent_headers())

View File

@ -1,3 +1,4 @@
import abc
import json import json
import re import re
import typing import typing
@ -12,13 +13,21 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, column_property, deferred, registry, relationship from sqlalchemy.orm import (
Mapped,
QueryPropertyDescriptor,
backref,
column_property,
deferred,
registry,
relationship,
)
from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.schema import Column, ForeignKey from sqlalchemy.schema import Column, ForeignKey
from sqlalchemy.types import BigInteger, Boolean, DateTime, Float, Integer, String, Text from sqlalchemy.types import BigInteger, Boolean, DateTime, Float, Integer, String, Text
from . import mail, utils, wikidata from . import mail, utils, wikidata, wikidata_api
from .database import now_utc, session from .database import now_utc, session
mapper_registry = registry() mapper_registry = registry()
@ -36,7 +45,7 @@ class Base(metaclass=DeclarativeMeta):
registry = mapper_registry registry = mapper_registry
metadata = mapper_registry.metadata metadata = mapper_registry.metadata
query = session.query_property() query: QueryPropertyDescriptor = session.query_property()
__init__ = mapper_registry.constructor __init__ = mapper_registry.constructor
@ -98,12 +107,12 @@ class Item(Base):
sitelinks = Column(postgresql.JSONB) sitelinks = Column(postgresql.JSONB)
claims = Column(postgresql.JSONB, nullable=False) claims = Column(postgresql.JSONB, nullable=False)
lastrevid = Column(Integer, nullable=False, unique=True) lastrevid = Column(Integer, nullable=False, unique=True)
locations = relationship( locations: Mapped[list["ItemLocation"]] = relationship(
"ItemLocation", cascade="all, delete-orphan", backref="item" "ItemLocation", cascade="all, delete-orphan", backref="item"
) )
qid = column_property("Q" + cast_to_string(item_id)) qid: Mapped[str] = column_property("Q" + cast_to_string(item_id))
wiki_extracts = relationship( wiki_extracts: Mapped[list["Extract"]] = relationship(
"Extract", "Extract",
collection_class=attribute_mapped_collection("site"), collection_class=attribute_mapped_collection("site"),
cascade="save-update, merge, delete, delete-orphan", cascade="save-update, merge, delete, delete-orphan",
@ -158,14 +167,15 @@ class Item(Base):
if d_list: if d_list:
return d_list[0]["value"] return d_list[0]["value"]
def get_aliases(self, lang="en"): def get_aliases(self, lang: str = "en") -> list[str]:
"""Get aliases."""
if lang not in self.aliases: if lang not in self.aliases:
if "en" not in self.aliases: if "en" not in self.aliases:
return [] return []
lang = "en" lang = "en"
return [a["value"] for a in self.aliases[lang]] return [a["value"] for a in self.aliases[lang]]
def get_part_of_names(self): def get_part_of_names(self) -> set[str]:
if not self.claims: if not self.claims:
return set() return set()
@ -186,11 +196,14 @@ class Item(Base):
return part_of_names return part_of_names
@property @property
def entity(self): def entity(self) -> wikidata_api.EntityType:
"""Entity."""
keys = ["labels", "aliases", "descriptions", "sitelinks", "claims"] keys = ["labels", "aliases", "descriptions", "sitelinks", "claims"]
return {key: getattr(self, key) for key in keys} return typing.cast(
wikidata_api.EntityType, {key: getattr(self, key) for key in keys}
)
def names(self, check_part_of=True): def names(self, check_part_of: bool = True) -> dict[str, list[str]] | None:
part_of_names = self.get_part_of_names() if check_part_of else set() part_of_names = self.get_part_of_names() if check_part_of else set()
d = wikidata.names_from_entity(self.entity) or defaultdict(list) d = wikidata.names_from_entity(self.entity) or defaultdict(list)
@ -258,7 +271,8 @@ class Item(Base):
"""Get QIDs of items listed instance of (P31) property.""" """Get QIDs of items listed instance of (P31) property."""
return [typing.cast(str, isa["id"]) for isa in self.get_isa()] return [typing.cast(str, isa["id"]) for isa in self.get_isa()]
def is_street(self, isa_qids=None): def is_street(self, isa_qids: typing.Collection[str] | None = None) -> bool:
"""Item represents a street."""
if isa_qids is None: if isa_qids is None:
isa_qids = self.get_isa_qids() isa_qids = self.get_isa_qids()
@ -272,7 +286,8 @@ class Item(Base):
} }
return bool(matching_types & set(isa_qids)) return bool(matching_types & set(isa_qids))
def is_watercourse(self, isa_qids=None): def is_watercourse(self, isa_qids: typing.Collection[str] | None = None) -> bool:
"""Item represents a watercourse."""
if isa_qids is None: if isa_qids is None:
isa_qids = self.get_isa_qids() isa_qids = self.get_isa_qids()
matching_types = { matching_types = {
@ -368,7 +383,7 @@ class Item(Base):
return text[: first_end_p_tag + len(close_tag)] return text[: first_end_p_tag + len(close_tag)]
def get_identifiers_tags(self): def get_identifiers_tags(self) -> dict[str, list[tuple[list[str], str]]]:
tags = defaultdict(list) tags = defaultdict(list)
for claim, osm_keys, label in property_map: for claim, osm_keys, label in property_map:
values = [ values = [
@ -386,7 +401,7 @@ class Item(Base):
tags[osm_key].append((values, label)) tags[osm_key].append((values, label))
return dict(tags) return dict(tags)
def get_identifiers(self): def get_identifiers(self) -> dict[str, list[str]]:
ret = {} ret = {}
for claim, osm_keys, label in property_map: for claim, osm_keys, label in property_map:
values = [ values = [
@ -420,8 +435,8 @@ class ItemIsA(Base):
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
isa_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) isa_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
item = relationship("Item", foreign_keys=[item_id]) item: Mapped[Item] = relationship("Item", foreign_keys=[item_id])
isa = relationship("Item", foreign_keys=[isa_id]) isa: Mapped[Item] = relationship("Item", foreign_keys=[isa_id])
class ItemLocation(Base): class ItemLocation(Base):
@ -458,7 +473,9 @@ def location_objects(
return locations return locations
class MapMixin: class MapBase(Base):
"""Map base class."""
@declared_attr @declared_attr
def __tablename__(cls): def __tablename__(cls):
return "planet_osm_" + cls.__name__.lower() return "planet_osm_" + cls.__name__.lower()
@ -468,7 +485,7 @@ class MapMixin:
admin_level = Column(String) admin_level = Column(String)
boundary = Column(String) boundary = Column(String)
tags = Column(postgresql.HSTORE) tags: Mapped[postgresql.HSTORE]
@declared_attr @declared_attr
def way(cls): def way(cls):
@ -477,67 +494,92 @@ class MapMixin:
) )
@declared_attr @declared_attr
def kml(cls): def kml(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Get object in KML format."""
return column_property(func.ST_AsKML(cls.way), deferred=True) return column_property(func.ST_AsKML(cls.way), deferred=True)
@declared_attr @declared_attr
def geojson_str(cls): def geojson_str(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Get object as GeoJSON string."""
return column_property( return column_property(
func.ST_AsGeoJSON(cls.way, maxdecimaldigits=6), deferred=True func.ST_AsGeoJSON(cls.way, maxdecimaldigits=6), deferred=True
) )
@declared_attr @declared_attr
def as_EWKT(cls): def as_EWKT(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""As EWKT."""
return column_property(func.ST_AsEWKT(cls.way), deferred=True) return column_property(func.ST_AsEWKT(cls.way), deferred=True)
@hybrid_property @hybrid_property
def has_street_address(self): def has_street_address(self) -> bool:
"""Object has street address."""
return "addr:housenumber" in self.tags and "addr:street" in self.tags return "addr:housenumber" in self.tags and "addr:street" in self.tags
def display_name(self): def display_name(self) -> str:
"""Name for display."""
for key in "bridge:name", "tunnel:name", "lock_name": for key in "bridge:name", "tunnel:name", "lock_name":
if key in self.tags: if key in self.tags:
return self.tags[key] return typing.cast(str, self.tags[key])
return ( return typing.cast(
self.name or self.tags.get("addr:housename") or self.tags.get("inscription") str,
self.name
or self.tags.get("addr:housename")
or self.tags.get("inscription"),
) )
def geojson(self): def geojson(self) -> dict[str, typing.Any]:
return json.loads(self.geojson_str) """Object GeoJSON parsed into Python data structure."""
return typing.cast(dict[str, typing.Any], json.loads(self.geojson_str))
def get_centroid(self): def get_centroid(self) -> tuple[float, float]:
"""Centroid."""
centroid = session.query(func.ST_AsText(func.ST_Centroid(self.way))).scalar() centroid = session.query(func.ST_AsText(func.ST_Centroid(self.way))).scalar()
lon, lat = re_point.match(centroid).groups() assert centroid
assert (m := re_point.match(centroid))
lon, lat = m.groups()
return (float(lat), float(lon)) return (float(lat), float(lon))
@classmethod @classmethod
def coords_within(cls, lat, lon): def coords_within(cls, lat: float, lon: float):
point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), 4326) point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), 4326)
return cls.query.filter( return cls.query.filter(
cls.admin_level.isnot(None), func.ST_Within(point, cls.way) cls.admin_level.isnot(None), func.ST_Within(point, cls.way)
).order_by(cls.area) ).order_by(cls.area)
@property @property
def id(self): def id(self) -> int:
"""OSM id."""
return abs(self.src_id) # relations have negative IDs return abs(self.src_id) # relations have negative IDs
@property @property
def identifier(self): @abc.abstractmethod
def type(self) -> str:
"""OSM type."""
@property
def identifier(self) -> str:
"""OSM identifier."""
return f"{self.type}/{self.id}" return f"{self.type}/{self.id}"
@property @property
def osm_url(self): def osm_url(self):
"""OSM URL."""
return f"https://www.openstreetmap.org/{self.type}/{self.id}" return f"https://www.openstreetmap.org/{self.type}/{self.id}"
class Point(MapMixin, Base): class Point(MapBase):
"""OSM planet point."""
type = "node" type = "node"
class Line(MapMixin, Base): class Line(MapBase):
"""OSM planet line."""
@property @property
def type(self): def type(self) -> str:
"""OSM type."""
return "way" if self.src_id > 0 else "relation" return "way" if self.src_id > 0 else "relation"
@classmethod @classmethod
@ -546,7 +588,9 @@ class Line(MapMixin, Base):
return cls.query.get(src_id) return cls.query.get(src_id)
class Polygon(MapMixin, Base): class Polygon(MapBase):
"""OSM planet polygon."""
way_area = Column(Float) way_area = Column(Float)
@classmethod @classmethod
@ -560,13 +604,15 @@ class Polygon(MapMixin, Base):
return "way" if self.src_id > 0 else "relation" return "way" if self.src_id > 0 else "relation"
@declared_attr @declared_attr
def area(cls): def area(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Polygon area."""
return column_property(func.ST_Area(cls.way, False), deferred=True) return column_property(func.ST_Area(cls.way, False), deferred=True)
@hybrid_property @hybrid_property
def area_in_sq_km(self) -> float: def area_in_sq_km(self) -> float:
"""Size of area in square km.""" """Size of area in square km."""
return self.area / (1000 * 1000) area: float = self.area
return area / (1000 * 1000)
class User(Base, UserMixin): class User(Base, UserMixin):
@ -601,6 +647,8 @@ class User(Base, UserMixin):
class EditSession(Base): class EditSession(Base):
"""Edit session."""
__tablename__ = "edit_session" __tablename__ = "edit_session"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey(User.id)) user_id = Column(Integer, ForeignKey(User.id))
@ -608,8 +656,10 @@ class EditSession(Base):
edit_list = Column(postgresql.JSONB) edit_list = Column(postgresql.JSONB)
comment = Column(String) comment = Column(String)
user = relationship("User") user: Mapped[User] = relationship("User")
changeset = relationship("Changeset", back_populates="edit_session", uselist=False) changeset: Mapped["Changeset"] = relationship(
"Changeset", back_populates="edit_session", uselist=False
)
class Changeset(Base): class Changeset(Base):
@ -623,14 +673,16 @@ class Changeset(Base):
update_count = Column(Integer, nullable=False) update_count = Column(Integer, nullable=False)
edit_session_id = Column(Integer, ForeignKey(EditSession.id)) edit_session_id = Column(Integer, ForeignKey(EditSession.id))
user = relationship( user: Mapped[User] = relationship(
"User", "User",
backref=backref( backref=backref(
"changesets", lazy="dynamic", order_by="Changeset.created.desc()" "changesets", lazy="dynamic", order_by="Changeset.created.desc()"
), ),
) )
edit_session = relationship("EditSession", back_populates="changeset") edit_session: Mapped[EditSession] = relationship(
"EditSession", back_populates="changeset"
)
class ChangesetEdit(Base): class ChangesetEdit(Base):
@ -644,7 +696,9 @@ class ChangesetEdit(Base):
osm_type = Column(osm_type_enum, primary_key=True) osm_type = Column(osm_type_enum, primary_key=True)
saved = Column(DateTime, default=now_utc(), nullable=False) saved = Column(DateTime, default=now_utc(), nullable=False)
changeset = relationship("Changeset", backref=backref("edits", lazy="dynamic")) changeset: Mapped[Changeset] = relationship(
"Changeset", backref=backref("edits", lazy="dynamic")
)
class SkipIsA(Base): class SkipIsA(Base):
@ -654,7 +708,7 @@ class SkipIsA(Base):
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
qid = column_property("Q" + cast_to_string(item_id)) qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item") item: Mapped[Item] = relationship("Item")
class ItemExtraKeys(Base): class ItemExtraKeys(Base):
@ -666,7 +720,7 @@ class ItemExtraKeys(Base):
note = Column(String) note = Column(String)
qid = column_property("Q" + cast_to_string(item_id)) qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item") item: Mapped[Item] = relationship("Item")
class Extract(Base): class Extract(Base):

View File

@ -1,17 +1,24 @@
from collections import OrderedDict """Nominatim."""
import json import json
import typing
from collections import OrderedDict
import requests import requests
from . import CallParams
Hit = dict[str, typing.Any]
class SearchError(Exception): class SearchError(Exception):
pass """Search error."""
def lookup_with_params(**kwargs): def lookup_with_params(**kwargs: str) -> list[Hit]:
url = "http://nominatim.openstreetmap.org/search" url = "http://nominatim.openstreetmap.org/search"
params = { params: CallParams = {
"format": "jsonv2", "format": "jsonv2",
"addressdetails": 1, "addressdetails": 1,
"extratags": 1, "extratags": 1,
@ -26,21 +33,24 @@ def lookup_with_params(**kwargs):
raise SearchError raise SearchError
try: try:
return json.loads(r.text, object_pairs_hook=OrderedDict) reply: list[Hit] = json.loads(r.text, object_pairs_hook=OrderedDict)
return reply
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
raise SearchError(r) raise SearchError(r)
def lookup(q): def lookup(q: str) -> list[Hit]:
"""Nominatim search."""
return lookup_with_params(q=q) return lookup_with_params(q=q)
def get_us_county(county, state): def get_us_county(county: str, state: str) -> Hit | None:
"""Search for US county and return resulting hit."""
if " " not in county and "county" not in county: if " " not in county and "county" not in county:
county += " county" county += " county"
results = lookup(q="{}, {}".format(county, state)) results = lookup(q="{}, {}".format(county, state))
def pred(hit): def pred(hit: Hit) -> typing.TypeGuard[Hit]:
return ( return (
"osm_type" in hit "osm_type" in hit
and hit["osm_type"] != "node" and hit["osm_type"] != "node"
@ -50,7 +60,8 @@ def get_us_county(county, state):
return next(filter(pred, results), None) return next(filter(pred, results), None)
def get_us_city(name, state): def get_us_city(name: str, state: str) -> Hit | None:
"""Search for US city and return resulting hit."""
results = lookup_with_params(city=name, state=state) results = lookup_with_params(city=name, state=state)
if len(results) != 1: if len(results) != 1:
results = [ results = [
@ -58,29 +69,32 @@ def get_us_city(name, state):
] ]
if len(results) != 1: if len(results) != 1:
print("more than one") print("more than one")
return return None
hit = results[0] hit = results[0]
if hit["type"] not in ("administrative", "city"): if hit["type"] not in ("administrative", "city"):
print("not a city") print("not a city")
return return None
if hit["osm_type"] == "node": if hit["osm_type"] == "node":
print("node") print("node")
return return None
if not hit["display_name"].startswith(name): if not hit["display_name"].startswith(name):
print("wrong name") print("wrong name")
return return None
assert "osm_type" in hit and "osm_id" in hit and "geotext" in hit assert "osm_type" in hit and "osm_id" in hit and "geotext" in hit
return hit return hit
def get_hit_name(hit): def get_hit_name(hit: Hit) -> str:
"""Get name from hit."""
address = hit.get("address") address = hit.get("address")
if not address: if not address:
assert isinstance(hit["display_name"], str)
return hit["display_name"] return hit["display_name"]
address_values = list(address.values()) address_values = list(address.values())
n1 = address_values[0] n1 = address_values[0]
if len(address) == 1: if len(address) == 1:
assert isinstance(n1, str)
return n1 return n1
country = address.pop("country", None) country = address.pop("country", None)
@ -102,13 +116,15 @@ def get_hit_name(hit):
return f"{n1}, {n2}, {country}" return f"{n1}, {n2}, {country}"
def get_hit_label(hit): def get_hit_label(hit: Hit) -> str:
"""Parse hit and generate label."""
tags = hit["extratags"] tags = hit["extratags"]
designation = tags.get("designation") designation = tags.get("designation")
category = hit["category"] category = hit["category"]
hit_type = hit["type"] hit_type = hit["type"]
if designation: if designation:
assert isinstance(designation, str)
return designation.replace("_", " ") return designation.replace("_", " ")
if category == "boundary" and hit_type == "administrative": if category == "boundary" and hit_type == "administrative":

View File

@ -1,14 +1,15 @@
from flask import current_app, session """OSM Authentication."""
from requests_oauthlib import OAuth1Session
from urllib.parse import urlencode import typing
from datetime import datetime from datetime import datetime
from flask import g from urllib.parse import urlencode
from .model import User
from . import user_agent_headers
import lxml.etree import lxml.etree
from flask import current_app, g, session
from requests_oauthlib import OAuth1Session
from . import user_agent_headers
from .model import User
osm_api_base = "https://api.openstreetmap.org/api/0.6" osm_api_base = "https://api.openstreetmap.org/api/0.6"
@ -67,11 +68,12 @@ def parse_userinfo_call(xml):
} }
def get_username(): def get_username() -> str | None:
"""Get username of current user."""
if "user_id" not in session: if "user_id" not in session:
return # not authorized return None # not authorized
user_id = session["user_id"] user_id = session["user_id"]
user = User.query.get(user_id) user = User.query.get(user_id)
return user.username return typing.cast(str, user.username)

View File

@ -1,6 +1,8 @@
"""Wikidata API."""
import json import json
import typing import typing
from typing import Any, cast from typing import cast
import requests import requests
import simplejson.errors import simplejson.errors
@ -9,7 +11,26 @@ from . import CallParams, user_agent_headers
wd_api_url = "https://www.wikidata.org/w/api.php" wd_api_url = "https://www.wikidata.org/w/api.php"
EntityType = dict[str, Any] Claims = dict[str, list[dict[str, typing.Any]]]
Sitelinks = dict[str, dict[str, typing.Any]]
class EntityType(typing.TypedDict, total=False):
"""Wikidata Entity."""
id: str
ns: str
type: str
pageid: int
title: str
labels: dict[str, typing.Any]
descriptions: dict[str, typing.Any]
claims: Claims
lastrevid: int
sitelinks: Sitelinks
modified: str
redirects: dict[str, typing.Any]
aliases: dict[str, list[dict[str, typing.Any]]]
def api_get(params: CallParams) -> requests.Response: def api_get(params: CallParams) -> requests.Response:

13
requirements.txt Normal file
View File

@ -0,0 +1,13 @@
flask
-e git+https://github.com/maxcountryman/flask-login.git#egg=Flask-Login
GeoIP
lxml
maxminddb
requests
sqlalchemy
requests_oauthlib
geoalchemy2
simplejson
user_agents
num2words
psycopg2