Add types and docstrings + upgrade to SQLAlchmey 2

This commit is contained in:
Edward Betts 2023-11-01 20:54:19 +00:00
parent 82671959bb
commit 5e8d1a99b0
8 changed files with 248 additions and 125 deletions

View file

@ -1,3 +1,4 @@
import abc
import json
import re
import typing
@ -12,13 +13,21 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, column_property, deferred, registry, relationship
from sqlalchemy.orm import (
Mapped,
QueryPropertyDescriptor,
backref,
column_property,
deferred,
registry,
relationship,
)
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.schema import Column, ForeignKey
from sqlalchemy.types import BigInteger, Boolean, DateTime, Float, Integer, String, Text
from . import mail, utils, wikidata
from . import mail, utils, wikidata, wikidata_api
from .database import now_utc, session
mapper_registry = registry()
@ -36,7 +45,7 @@ class Base(metaclass=DeclarativeMeta):
registry = mapper_registry
metadata = mapper_registry.metadata
query = session.query_property()
query: QueryPropertyDescriptor = session.query_property()
__init__ = mapper_registry.constructor
@ -98,12 +107,12 @@ class Item(Base):
sitelinks = Column(postgresql.JSONB)
claims = Column(postgresql.JSONB, nullable=False)
lastrevid = Column(Integer, nullable=False, unique=True)
locations = relationship(
locations: Mapped[list["ItemLocation"]] = relationship(
"ItemLocation", cascade="all, delete-orphan", backref="item"
)
qid = column_property("Q" + cast_to_string(item_id))
qid: Mapped[str] = column_property("Q" + cast_to_string(item_id))
wiki_extracts = relationship(
wiki_extracts: Mapped[list["Extract"]] = relationship(
"Extract",
collection_class=attribute_mapped_collection("site"),
cascade="save-update, merge, delete, delete-orphan",
@ -158,14 +167,15 @@ class Item(Base):
if d_list:
return d_list[0]["value"]
def get_aliases(self, lang="en"):
def get_aliases(self, lang: str = "en") -> list[str]:
"""Get aliases."""
if lang not in self.aliases:
if "en" not in self.aliases:
return []
lang = "en"
return [a["value"] for a in self.aliases[lang]]
def get_part_of_names(self):
def get_part_of_names(self) -> set[str]:
if not self.claims:
return set()
@ -186,11 +196,14 @@ class Item(Base):
return part_of_names
@property
def entity(self):
def entity(self) -> wikidata_api.EntityType:
"""Entity."""
keys = ["labels", "aliases", "descriptions", "sitelinks", "claims"]
return {key: getattr(self, key) for key in keys}
return typing.cast(
wikidata_api.EntityType, {key: getattr(self, key) for key in keys}
)
def names(self, check_part_of=True):
def names(self, check_part_of: bool = True) -> dict[str, list[str]] | None:
part_of_names = self.get_part_of_names() if check_part_of else set()
d = wikidata.names_from_entity(self.entity) or defaultdict(list)
@ -258,7 +271,8 @@ class Item(Base):
"""Get QIDs of items listed instance of (P31) property."""
return [typing.cast(str, isa["id"]) for isa in self.get_isa()]
def is_street(self, isa_qids=None):
def is_street(self, isa_qids: typing.Collection[str] | None = None) -> bool:
"""Item represents a street."""
if isa_qids is None:
isa_qids = self.get_isa_qids()
@ -272,7 +286,8 @@ class Item(Base):
}
return bool(matching_types & set(isa_qids))
def is_watercourse(self, isa_qids=None):
def is_watercourse(self, isa_qids: typing.Collection[str] | None = None) -> bool:
"""Item represents a watercourse."""
if isa_qids is None:
isa_qids = self.get_isa_qids()
matching_types = {
@ -368,7 +383,7 @@ class Item(Base):
return text[: first_end_p_tag + len(close_tag)]
def get_identifiers_tags(self):
def get_identifiers_tags(self) -> dict[str, list[tuple[list[str], str]]]:
tags = defaultdict(list)
for claim, osm_keys, label in property_map:
values = [
@ -386,7 +401,7 @@ class Item(Base):
tags[osm_key].append((values, label))
return dict(tags)
def get_identifiers(self):
def get_identifiers(self) -> dict[str, list[str]]:
ret = {}
for claim, osm_keys, label in property_map:
values = [
@ -420,8 +435,8 @@ class ItemIsA(Base):
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
isa_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
item = relationship("Item", foreign_keys=[item_id])
isa = relationship("Item", foreign_keys=[isa_id])
item: Mapped[Item] = relationship("Item", foreign_keys=[item_id])
isa: Mapped[Item] = relationship("Item", foreign_keys=[isa_id])
class ItemLocation(Base):
@ -458,7 +473,9 @@ def location_objects(
return locations
class MapMixin:
class MapBase(Base):
"""Map base class."""
@declared_attr
def __tablename__(cls):
return "planet_osm_" + cls.__name__.lower()
@ -468,7 +485,7 @@ class MapMixin:
admin_level = Column(String)
boundary = Column(String)
tags = Column(postgresql.HSTORE)
tags: Mapped[postgresql.HSTORE]
@declared_attr
def way(cls):
@ -477,67 +494,92 @@ class MapMixin:
)
@declared_attr
def kml(cls):
def kml(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Get object in KML format."""
return column_property(func.ST_AsKML(cls.way), deferred=True)
@declared_attr
def geojson_str(cls):
def geojson_str(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Get object as GeoJSON string."""
return column_property(
func.ST_AsGeoJSON(cls.way, maxdecimaldigits=6), deferred=True
)
@declared_attr
def as_EWKT(cls):
def as_EWKT(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""As EWKT."""
return column_property(func.ST_AsEWKT(cls.way), deferred=True)
@hybrid_property
def has_street_address(self):
def has_street_address(self) -> bool:
"""Object has street address."""
return "addr:housenumber" in self.tags and "addr:street" in self.tags
def display_name(self):
def display_name(self) -> str:
"""Name for display."""
for key in "bridge:name", "tunnel:name", "lock_name":
if key in self.tags:
return self.tags[key]
return typing.cast(str, self.tags[key])
return (
self.name or self.tags.get("addr:housename") or self.tags.get("inscription")
return typing.cast(
str,
self.name
or self.tags.get("addr:housename")
or self.tags.get("inscription"),
)
def geojson(self):
return json.loads(self.geojson_str)
def geojson(self) -> dict[str, typing.Any]:
"""Object GeoJSON parsed into Python data structure."""
return typing.cast(dict[str, typing.Any], json.loads(self.geojson_str))
def get_centroid(self):
def get_centroid(self) -> tuple[float, float]:
"""Centroid."""
centroid = session.query(func.ST_AsText(func.ST_Centroid(self.way))).scalar()
lon, lat = re_point.match(centroid).groups()
assert centroid
assert (m := re_point.match(centroid))
lon, lat = m.groups()
return (float(lat), float(lon))
@classmethod
def coords_within(cls, lat, lon):
def coords_within(cls, lat: float, lon: float):
point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), 4326)
return cls.query.filter(
cls.admin_level.isnot(None), func.ST_Within(point, cls.way)
).order_by(cls.area)
@property
def id(self):
def id(self) -> int:
"""OSM id."""
return abs(self.src_id) # relations have negative IDs
@property
def identifier(self):
@abc.abstractmethod
def type(self) -> str:
"""OSM type."""
@property
def identifier(self) -> str:
"""OSM identifier."""
return f"{self.type}/{self.id}"
@property
def osm_url(self):
"""OSM URL."""
return f"https://www.openstreetmap.org/{self.type}/{self.id}"
class Point(MapMixin, Base):
class Point(MapBase):
"""OSM planet point."""
type = "node"
class Line(MapMixin, Base):
class Line(MapBase):
"""OSM planet line."""
@property
def type(self):
def type(self) -> str:
"""OSM type."""
return "way" if self.src_id > 0 else "relation"
@classmethod
@ -546,7 +588,9 @@ class Line(MapMixin, Base):
return cls.query.get(src_id)
class Polygon(MapMixin, Base):
class Polygon(MapBase):
"""OSM planet polygon."""
way_area = Column(Float)
@classmethod
@ -560,13 +604,15 @@ class Polygon(MapMixin, Base):
return "way" if self.src_id > 0 else "relation"
@declared_attr
def area(cls):
def area(cls) -> sqlalchemy.orm.properties.ColumnProperty:
"""Polygon area."""
return column_property(func.ST_Area(cls.way, False), deferred=True)
@hybrid_property
def area_in_sq_km(self) -> float:
"""Size of area in square km."""
return self.area / (1000 * 1000)
area: float = self.area
return area / (1000 * 1000)
class User(Base, UserMixin):
@ -601,6 +647,8 @@ class User(Base, UserMixin):
class EditSession(Base):
"""Edit session."""
__tablename__ = "edit_session"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey(User.id))
@ -608,8 +656,10 @@ class EditSession(Base):
edit_list = Column(postgresql.JSONB)
comment = Column(String)
user = relationship("User")
changeset = relationship("Changeset", back_populates="edit_session", uselist=False)
user: Mapped[User] = relationship("User")
changeset: Mapped["Changeset"] = relationship(
"Changeset", back_populates="edit_session", uselist=False
)
class Changeset(Base):
@ -623,14 +673,16 @@ class Changeset(Base):
update_count = Column(Integer, nullable=False)
edit_session_id = Column(Integer, ForeignKey(EditSession.id))
user = relationship(
user: Mapped[User] = relationship(
"User",
backref=backref(
"changesets", lazy="dynamic", order_by="Changeset.created.desc()"
),
)
edit_session = relationship("EditSession", back_populates="changeset")
edit_session: Mapped[EditSession] = relationship(
"EditSession", back_populates="changeset"
)
class ChangesetEdit(Base):
@ -644,7 +696,9 @@ class ChangesetEdit(Base):
osm_type = Column(osm_type_enum, primary_key=True)
saved = Column(DateTime, default=now_utc(), nullable=False)
changeset = relationship("Changeset", backref=backref("edits", lazy="dynamic"))
changeset: Mapped[Changeset] = relationship(
"Changeset", backref=backref("edits", lazy="dynamic")
)
class SkipIsA(Base):
@ -654,7 +708,7 @@ class SkipIsA(Base):
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item")
item: Mapped[Item] = relationship("Item")
class ItemExtraKeys(Base):
@ -666,7 +720,7 @@ class ItemExtraKeys(Base):
note = Column(String)
qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item")
item: Mapped[Item] = relationship("Item")
class Extract(Base):