Add types and docstrings

This commit is contained in:
Edward Betts 2023-10-31 14:40:55 +00:00
parent 2229605672
commit 2e9ea504f0
1 changed files with 35 additions and 16 deletions

View File

@ -4,6 +4,7 @@ import typing
from collections import defaultdict
from typing import Any
import sqlalchemy
from flask_login import UserMixin
from geoalchemy2 import Geometry
from sqlalchemy import func
@ -15,7 +16,6 @@ from sqlalchemy.orm import backref, column_property, deferred, registry, relatio
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.schema import Column, ForeignKey
from sqlalchemy.sql.expression import cast
from sqlalchemy.types import BigInteger, Boolean, DateTime, Float, Integer, String, Text
from . import mail, utils, wikidata
@ -24,7 +24,14 @@ from .database import now_utc, session
mapper_registry = registry()
def cast_to_string(v: Column[int]) -> sqlalchemy.sql.elements.Cast[str]:
"""Cast an value to a string."""
return sqlalchemy.sql.expression.cast(v, String)
class Base(metaclass=DeclarativeMeta):
"""Database model base class."""
__abstract__ = True
registry = mapper_registry
@ -94,7 +101,7 @@ class Item(Base):
locations = relationship(
"ItemLocation", cascade="all, delete-orphan", backref="item"
)
qid = column_property("Q" + cast(item_id, String))
qid = column_property("Q" + cast_to_string(item_id))
wiki_extracts = relationship(
"Extract",
@ -106,6 +113,7 @@ class Item(Base):
@classmethod
def get_by_qid(cls: typing.Type[T], qid: str) -> T | None:
"""Lookup Item via QID."""
if qid and len(qid) > 1 and qid[0].upper() == "Q" and qid[1:].isdigit():
obj: T = cls.query.get(qid[1:])
return obj
@ -246,8 +254,9 @@ class Item(Base):
isa_list.append(of_qualifier["datavalue"]["value"])
return isa_list
def get_isa_qids(self):
return [isa["id"] for isa in self.get_isa()]
def get_isa_qids(self) -> list[str]:
"""Get QIDs of items listed instance of (P31) property."""
return [typing.cast(str, isa["id"]) for isa in self.get_isa()]
def is_street(self, isa_qids=None):
if isa_qids is None:
@ -281,10 +290,12 @@ class Item(Base):
isa_qids = set(self.get_isa_qids())
return self.is_street(isa_qids) or self.is_watercourse(isa_qids)
def is_tram_stop(self):
def is_tram_stop(self) -> bool:
"""Item is a tram stop."""
return "Q2175765" in self.get_isa_qids()
def alert_admin_about_bad_time(self, v):
def alert_admin_about_bad_time(self, v: utils.WikibaseTime) -> None:
"""Send an email to admin when encountering an unparsable time in Wikibase."""
body = (
"Wikidata item has an unsupported time precision\n\n"
+ self.wd_url
@ -294,9 +305,10 @@ class Item(Base):
)
mail.send_mail(f"OWL Map: bad time value in {self.qid}", body)
def time_claim(self, pid):
def time_claim(self, pid: str) -> list[str]:
"""Read values from time statement."""
ret = []
for v in self.get_claim(pid):
for v in typing.cast(list[utils.WikibaseTime | None], self.get_claim(pid)):
if not v:
continue
try:
@ -312,15 +324,18 @@ class Item(Base):
return ret
def closed(self):
def closed(self) -> list[str]:
"""Date when item closed."""
return self.time_claim("P3999")
def first_paragraph_language(self, lang):
def first_paragraph_language(self, lang: str) -> str | None:
"""First paragraph of Wikipedia article in the given languages."""
if lang not in self.sitelinks():
return
return None
extract = self.extracts.get(lang)
if not extract:
return
return None
assert isinstance(extract, str)
empty_list = [
"<p><span></span></p>",
@ -399,6 +414,8 @@ class Item(Base):
class ItemIsA(Base):
"""Item IsA."""
__tablename__ = "item_isa"
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
isa_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
@ -408,14 +425,16 @@ class ItemIsA(Base):
class ItemLocation(Base):
"""Location of an item."""
__tablename__ = "item_location"
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
property_id = Column(Integer, primary_key=True)
statement_order = Column(Integer, primary_key=True)
location = Column(Geometry("POINT", srid=4326, spatial_index=True), nullable=False)
qid = column_property("Q" + cast(item_id, String))
pid = column_property("P" + cast(item_id, String))
qid = column_property("Q" + cast_to_string(item_id))
pid = column_property("P" + cast_to_string(property_id))
def get_lat_lon(self) -> tuple[float, float]:
"""Get latitude and longitude of item."""
@ -633,7 +652,7 @@ class SkipIsA(Base):
__tablename__ = "skip_isa"
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
qid = column_property("Q" + cast(item_id, String))
qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item")
@ -645,7 +664,7 @@ class ItemExtraKeys(Base):
item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
tag_or_key = Column(String, primary_key=True)
note = Column(String)
qid = column_property("Q" + cast(item_id, String))
qid = column_property("Q" + cast_to_string(item_id))
item = relationship("Item")