Type hints and docstrings.

This commit is contained in:
Edward Betts 2023-05-17 16:28:44 +00:00
parent dd9078f258
commit b8ed296f78

View file

@ -6,12 +6,9 @@ import typing
import flask import flask
import geoalchemy2 import geoalchemy2
from sqlalchemy import and_, func, or_, text import sqlalchemy
from sqlalchemy.dialects import postgresql from sqlalchemy import and_, or_
from sqlalchemy.orm import selectinload
from sqlalchemy.sql import select from sqlalchemy.sql import select
from sqlalchemy.sql.expression import cast, column, literal, union
from sqlalchemy.types import Float
from matcher import database, model, wikidata, wikidata_api from matcher import database, model, wikidata, wikidata_api
from matcher.planet import line, point, polygon from matcher.planet import line, point, polygon
@ -51,10 +48,11 @@ def get_country_iso3166_1(lat: float, lon: float) -> set[str]:
Normally there should be only one country. Normally there should be only one country.
""" """
point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), srid) point = sqlalchemy.func.ST_SetSRID(sqlalchemy.func.ST_MakePoint(lon, lat), srid)
alpha2_codes: set[str] = set() alpha2_codes: set[str] = set()
q = model.Polygon.query.filter( q = model.Polygon.query.filter(
func.ST_Covers(model.Polygon.way, point), model.Polygon.admin_level == "2" sqlalchemy.func.ST_Covers(model.Polygon.way, point),
model.Polygon.admin_level == "2",
) )
for country in q: for country in q:
alpha2: str = country.tags.get("ISO3166-1") alpha2: str = country.tags.get("ISO3166-1")
@ -90,13 +88,15 @@ def is_street_number_first(lat: float, lon: float) -> bool:
def make_envelope(bounds: list[float]) -> geoalchemy2.functions.ST_MakeEnvelope: def make_envelope(bounds: list[float]) -> geoalchemy2.functions.ST_MakeEnvelope:
"""Make en envelope for the given bounds.""" """Make en envelope for the given bounds."""
return func.ST_MakeEnvelope(*bounds, srid) return sqlalchemy.func.ST_MakeEnvelope(*bounds, srid)
def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]: def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]:
"""Get centroid of bounding box.""" """Get centroid of bounding box."""
bbox = make_envelope(bbox) bbox = make_envelope(bbox)
centroid = database.session.query(func.ST_AsText(func.ST_Centroid(bbox))).scalar() centroid = database.session.query(
sqlalchemy.func.ST_AsText(sqlalchemy.func.ST_Centroid(bbox))
).scalar()
m = re_point.match(centroid) m = re_point.match(centroid)
assert m assert m
lon, lat = m.groups() lon, lat = m.groups()
@ -107,16 +107,25 @@ def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]:
def make_envelope_around_point( def make_envelope_around_point(
lat: float, lon: float, distance: float lat: float, lon: float, distance: float
) -> geoalchemy2.functions.ST_MakeEnvelope: ) -> geoalchemy2.functions.ST_MakeEnvelope:
"""Make an envelope around a point, the distance parameter specifies the size."""
conn = database.session.connection() conn = database.session.connection()
p = func.ST_MakePoint(lon, lat) p = sqlalchemy.func.ST_MakePoint(lon, lat)
s = select( s = select(
[ [
func.ST_AsText(func.ST_Project(p, distance, func.radians(0))), sqlalchemy.func.ST_AsText(
func.ST_AsText(func.ST_Project(p, distance, func.radians(90))), sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(0))
func.ST_AsText(func.ST_Project(p, distance, func.radians(180))), ),
func.ST_AsText(func.ST_Project(p, distance, func.radians(270))), sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(90))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(180))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(270))
),
] ]
) )
row = conn.execute(s).fetchone() row = conn.execute(s).fetchone()
@ -127,7 +136,7 @@ def make_envelope_around_point(
south = coords[2][1] south = coords[2][1]
west = coords[3][0] west = coords[3][0]
return func.ST_MakeEnvelope(west, south, east, north, srid) return sqlalchemy.func.ST_MakeEnvelope(west, south, east, north, srid)
def drop_way_area(tags: TagsType) -> TagsType: def drop_way_area(tags: TagsType) -> TagsType:
@ -146,13 +155,13 @@ def get_part_of(table_name, src_id, bbox):
[ [
polygon.c.osm_id, polygon.c.osm_id,
polygon.c.tags, polygon.c.tags,
func.ST_Area(func.ST_Collect(polygon.c.way)), sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)),
] ]
) )
.where( .where(
and_( and_(
func.ST_Intersects(bbox, polygon.c.way), sqlalchemy.func.ST_Intersects(bbox, polygon.c.way),
func.ST_Covers(polygon.c.way, table_alias.c.way), sqlalchemy.func.ST_Covers(polygon.c.way, table_alias.c.way),
table_alias.c.osm_id == src_id, table_alias.c.osm_id == src_id,
polygon.c.tags.has_key("name"), polygon.c.tags.has_key("name"),
or_( or_(
@ -227,8 +236,8 @@ def get_items_in_bbox(bbox: list[float]):
q = ( q = (
model.Item.query.join(model.ItemLocation) model.Item.query.join(model.ItemLocation)
.filter(func.ST_Covers(db_bbox, model.ItemLocation.location)) .filter(sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location))
.options(selectinload(model.Item.locations)) .options(sqlalchemy.orm.selectinload(model.Item.locations))
) )
return q return q
@ -239,7 +248,7 @@ def get_osm_with_wikidata_tag(bbox, isa_filter=None):
extra_sql = "" extra_sql = ""
if isa_filter: if isa_filter:
q = model.Item.query.join(model.ItemLocation).filter( q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(make_envelope(bbox), model.ItemLocation.location) sqlalchemy.func.ST_Covers(make_envelope(bbox), model.ItemLocation.location)
) )
q = add_isa_filter(q, isa_filter) q = add_isa_filter(q, isa_filter)
qids = [isa.qid for isa in q] qids = [isa.qid for isa in q]
@ -274,7 +283,7 @@ WHERE tags ? 'wikidata'
+ extra_sql + extra_sql
) )
conn = database.session.connection() conn = database.session.connection()
result = conn.execute(text(sql)) result = conn.execute(sqlalchemy.text(sql))
# print(sql) # print(sql)
@ -344,7 +353,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
osm_list = collections.defaultdict(list) osm_list = collections.defaultdict(list)
skip_isa = {row[0] for row in database.session.query(model.SkipIsA.item_id)} skip_isa: set[int] = {
row[0] for row in database.session.query(model.SkipIsA.item_id)
}
tram_stop_id = 41176 tram_stop_id = 41176
airport_id = 1248784 airport_id = 1248784
@ -352,7 +363,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
if {tram_stop_id, airport_id, aerodrome_id} & set(isa_list): if {tram_stop_id, airport_id, aerodrome_id} & set(isa_list):
skip_isa.add(41176) # building (Q41176) skip_isa.add(41176) # building (Q41176)
seen = set(isa_list) | skip_isa seen: set[int] = set(isa_list) | skip_isa
stop = { stop = {
"Q11799049": "public institution", "Q11799049": "public institution",
"Q7075": "library", "Q7075": "library",
@ -364,7 +375,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
continue continue
isa_qid: str = typing.cast(str, isa.qid) isa_qid: str = typing.cast(str, isa.qid)
isa_path = isa_path + [{"qid": isa_qid, "label": isa.label()}] isa_path = isa_path + [{"qid": isa_qid, "label": isa.label()}]
osm = [v for v in isa.get_claim("P1282") if v not in skip_tags] osm: list[str] = [
typing.cast(str, v) for v in isa.get_claim("P1282") if v not in skip_tags
]
osm += [ osm += [
extra.tag_or_key extra.tag_or_key
@ -378,7 +391,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
# item is specific enough, no need to keep walking the item hierarchy # item is specific enough, no need to keep walking the item hierarchy
continue continue
check = set() check: set[int] = set()
properties = [ properties = [
("P279", "subclass of"), ("P279", "subclass of"),
("P140", "religion"), ("P140", "religion"),
@ -389,11 +402,15 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
] ]
for pid, label in properties: for pid, label in properties:
check |= {v["numeric-id"] for v in (isa.get_claim(pid) or []) if v} check |= {
typing.cast(dict[str, int], v)["numeric-id"]
for v in (isa.get_claim(pid) or [])
if v
}
print(isa.qid, isa.label(), check) print(isa.qid, isa.label(), check)
isa_list = check - seen isa_list_set = check - seen
seen.update(isa_list) seen.update(isa_list_set)
isa_items += [(isa, isa_path) for isa in get_items(isa_list)] isa_items += [(isa, isa_path) for isa in get_items(isa_list)]
return {key: list(values) for key, values in osm_list.items()} return {key: list(values) for key, values in osm_list.items()}
@ -467,7 +484,7 @@ def get_tags_for_isa_item(item):
def add_isa_filter(q, isa_qids): def add_isa_filter(q, isa_qids):
q_subclass = database.session.query(model.Item.qid).filter( q_subclass = database.session.query(model.Item.qid).filter(
func.jsonb_path_query_array( sqlalchemy.func.jsonb_path_query_array(
model.Item.claims, model.Item.claims,
"$.P279[*].mainsnak.datavalue.value.id", "$.P279[*].mainsnak.datavalue.value.id",
).bool_op("?|")(list(isa_qids)) ).bool_op("?|")(list(isa_qids))
@ -475,7 +492,7 @@ def add_isa_filter(q, isa_qids):
subclass_qid = {qid for qid, in q_subclass.all()} subclass_qid = {qid for qid, in q_subclass.all()}
isa = func.jsonb_path_query_array( isa = sqlalchemy.func.jsonb_path_query_array(
model.Item.claims, model.Item.claims,
"$.P31[*].mainsnak.datavalue.value.id", "$.P31[*].mainsnak.datavalue.value.id",
).bool_op("?|") ).bool_op("?|")
@ -484,7 +501,7 @@ def add_isa_filter(q, isa_qids):
def wikidata_items_count(bounds, isa_filter=None): def wikidata_items_count(bounds, isa_filter=None):
q = model.Item.query.join(model.ItemLocation).filter( q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(make_envelope(bounds), model.ItemLocation.location) sqlalchemy.func.ST_Covers(make_envelope(bounds), model.ItemLocation.location)
) )
if isa_filter: if isa_filter:
@ -499,7 +516,7 @@ def wikidata_isa_counts(bounds, isa_filter=None):
db_bbox = make_envelope(bounds) db_bbox = make_envelope(bounds)
q = model.Item.query.join(model.ItemLocation).filter( q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(db_bbox, model.ItemLocation.location) sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location)
) )
if isa_filter: if isa_filter:
@ -529,8 +546,11 @@ def wikidata_isa_counts(bounds, isa_filter=None):
return isa_count return isa_count
def get_tag_filter(tags, tag_list): def get_tag_filter(
tags: sqlalchemy.sql.schema.Column, tag_list: list[str]
) -> list[sqlalchemy.sql.elements.BooleanClauseList]:
tag_filter = [] tag_filter = []
print("tags type:", type(tags))
for tag_or_key in tag_list: for tag_or_key in tag_list:
if tag_or_key.startswith("Key:"): if tag_or_key.startswith("Key:"):
key = tag_or_key[4:] key = tag_or_key[4:]
@ -544,10 +564,11 @@ def get_tag_filter(tags, tag_list):
for prefix in tag_prefixes: for prefix in tag_prefixes:
tag_filter.append(tags[f"{prefix}:{k}"] == v) tag_filter.append(tags[f"{prefix}:{k}"] == v)
print("tag_filter type:", [type(i) for i in tag_filter])
return tag_filter return tag_filter
def get_preset_translations(): def get_preset_translations() -> dict[str, typing.Any]:
app = flask.current_app app = flask.current_app
country_language = { country_language = {
"AU": "en-AU", # Australia "AU": "en-AU", # Australia
@ -569,7 +590,9 @@ def get_preset_translations():
continue continue
try: try:
return json_data[lang_code]["presets"]["presets"] return typing.cast(
dict[str, typing.Any], json_data[lang_code]["presets"]["presets"]
)
except KeyError: except KeyError:
pass pass
@ -665,8 +688,13 @@ def address_node_label(tags: TagsType) -> str | None:
def get_address_nodes_within_building(osm_id, bbox_list): def get_address_nodes_within_building(osm_id, bbox_list):
q = model.Point.query.filter( q = model.Point.query.filter(
polygon.c.osm_id == osm_id, polygon.c.osm_id == osm_id,
or_(*[func.ST_Intersects(bbox, model.Point.way) for bbox in bbox_list]), or_(
func.ST_Covers(polygon.c.way, model.Point.way), *[
sqlalchemy.func.ST_Intersects(bbox, model.Point.way)
for bbox in bbox_list
]
),
sqlalchemy.func.ST_Covers(polygon.c.way, model.Point.way),
model.Point.tags.has_key("addr:street"), model.Point.tags.has_key("addr:street"),
model.Point.tags.has_key("addr:housenumber"), model.Point.tags.has_key("addr:housenumber"),
) )
@ -708,9 +736,11 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
for loc in item.locations for loc in item.locations
] ]
null_area = cast(None, Float) null_area = sqlalchemy.sql.expression.cast(None, sqlalchemy.types.Float)
dist = column("dist") dist = sqlalchemy.sql.expression.column("dist")
tags = column("tags", postgresql.HSTORE) tags = sqlalchemy.sql.expression.column(
"tags", sqlalchemy.dialects.postgresql.HSTORE
)
tag_list = get_item_tags(item) tag_list = get_item_tags(item)
# tag_filters = get_tag_filter(point.c.tags, tag_list) # tag_filters = get_tag_filter(point.c.tags, tag_list)
@ -719,20 +749,27 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_point = ( s_point = (
select( select(
[ [
literal("point").label("t"), sqlalchemy.sql.expression.literal("point").label("t"),
point.c.osm_id, point.c.osm_id,
point.c.tags.label("tags"), point.c.tags.label("tags"),
func.min( sqlalchemy.func.min(
func.ST_DistanceSphere(model.ItemLocation.location, point.c.way) sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, point.c.way
)
).label("dist"), ).label("dist"),
func.ST_AsText(point.c.way), sqlalchemy.func.ST_AsText(point.c.way),
func.ST_AsGeoJSON(point.c.way), sqlalchemy.func.ST_AsGeoJSON(point.c.way),
null_area, null_area,
] ]
) )
.where( .where(
and_( and_(
or_(*[func.ST_Intersects(bbox, point.c.way) for bbox in bbox_list]), or_(
*[
sqlalchemy.func.ST_Intersects(bbox, point.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id, model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(point.c.tags, tag_list)), or_(*get_tag_filter(point.c.tags, tag_list)),
) )
@ -743,20 +780,29 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_line = ( s_line = (
select( select(
[ [
literal("line").label("t"), sqlalchemy.sql.expression.literal("line").label("t"),
line.c.osm_id, line.c.osm_id,
line.c.tags.label("tags"), line.c.tags.label("tags"),
func.min( sqlalchemy.func.min(
func.ST_DistanceSphere(model.ItemLocation.location, line.c.way) sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, line.c.way
)
).label("dist"), ).label("dist"),
func.ST_AsText(func.ST_Centroid(func.ST_Collect(line.c.way))), sqlalchemy.func.ST_AsText(
func.ST_AsGeoJSON(func.ST_Collect(line.c.way)), sqlalchemy.func.ST_Centroid(sqlalchemy.func.ST_Collect(line.c.way))
),
sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(line.c.way)),
null_area, null_area,
] ]
) )
.where( .where(
and_( and_(
or_(*[func.ST_Intersects(bbox, line.c.way) for bbox in bbox_list]), or_(
*[
sqlalchemy.func.ST_Intersects(bbox, line.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id, model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(line.c.tags, tag_list)), or_(*get_tag_filter(line.c.tags, tag_list)),
) )
@ -767,33 +813,48 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_polygon = ( s_polygon = (
select( select(
[ [
literal("polygon").label("t"), sqlalchemy.sql.expression.literal("polygon").label("t"),
polygon.c.osm_id, polygon.c.osm_id,
polygon.c.tags.label("tags"), polygon.c.tags.label("tags"),
func.min( sqlalchemy.func.min(
func.ST_DistanceSphere(model.ItemLocation.location, polygon.c.way) sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, polygon.c.way
)
).label("dist"), ).label("dist"),
func.ST_AsText(func.ST_Centroid(func.ST_Collect(polygon.c.way))), sqlalchemy.func.ST_AsText(
func.ST_AsGeoJSON(func.ST_Collect(polygon.c.way)), sqlalchemy.func.ST_Centroid(
func.ST_Area(func.ST_Collect(polygon.c.way)), sqlalchemy.func.ST_Collect(polygon.c.way)
)
),
sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(polygon.c.way)),
sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)),
] ]
) )
.where( .where(
and_( and_(
or_(*[func.ST_Intersects(bbox, polygon.c.way) for bbox in bbox_list]), or_(
*[
sqlalchemy.func.ST_Intersects(bbox, polygon.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id, model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(polygon.c.tags, tag_list)), or_(*get_tag_filter(polygon.c.tags, tag_list)),
) )
) )
.group_by(polygon.c.osm_id, polygon.c.tags) .group_by(polygon.c.osm_id, polygon.c.tags)
.having( .having(
func.ST_Area(func.ST_Collect(polygon.c.way)) sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way))
< 20 * func.ST_Area(bbox_list[0]) < 20 * sqlalchemy.func.ST_Area(bbox_list[0])
) )
) )
tables = ([] if item_is_linear_feature else [s_point]) + [s_line, s_polygon] tables = ([] if item_is_linear_feature else [s_point]) + [s_line, s_polygon]
s = select([union(*tables).alias()]).where(dist < max_distance).order_by(dist) s = (
select([sqlalchemy.sql.expression.union(*tables).alias()])
.where(dist < max_distance)
.order_by(dist)
)
if names: if names:
s = s.where(or_(tags["name"].in_(names), tags["old_name"].in_(names))) s = s.where(or_(tags["name"].in_(names), tags["old_name"].in_(names)))
@ -1056,24 +1117,24 @@ def missing_wikidata_items(qids, lat, lon):
return dict(items=items, isa_count=isa_count) return dict(items=items, isa_count=isa_count)
def isa_incremental_search(search_terms): def isa_incremental_search(search_terms: str) -> list[dict[str, str]]:
en_label = func.jsonb_extract_path_text(model.Item.labels, "en", "value") """Incremental search."""
en_label = sqlalchemy.func.jsonb_extract_path_text(model.Item.labels, "en", "value")
q = model.Item.query.filter( q = model.Item.query.filter(
model.Item.claims.has_key("P1282"), model.Item.claims.has_key("P1282"),
en_label.ilike(f"%{search_terms}%"), en_label.ilike(f"%{search_terms}%"),
func.length(en_label) < 20, sqlalchemy.func.length(en_label) < 20,
) )
# print(q.statement.compile(compile_kwargs={"literal_binds": True})) # print(q.statement.compile(compile_kwargs={"literal_binds": True}))
ret = [] return [
for item in q: {
cur = {
"qid": item.qid, "qid": item.qid,
"label": item.label(), "label": item.label(),
} }
ret.append(cur) for item in q
return ret ]
class PlaceItems(typing.TypedDict): class PlaceItems(typing.TypedDict):
@ -1091,7 +1152,7 @@ def get_place_items(osm_type: str, osm_id: int) -> PlaceItems:
model.Item.query.join(model.ItemLocation) model.Item.query.join(model.ItemLocation)
.join( .join(
model.Polygon, model.Polygon,
func.ST_Covers(model.Polygon.way, model.ItemLocation.location), sqlalchemy.func.ST_Covers(model.Polygon.way, model.ItemLocation.location),
) )
.filter(model.Polygon.src_id == src_id) .filter(model.Polygon.src_id == src_id)
) )