Type hints and docstrings.

This commit is contained in:
Edward Betts 2023-05-17 16:28:44 +00:00
parent dd9078f258
commit b8ed296f78

View file

@ -6,12 +6,9 @@ import typing
import flask
import geoalchemy2
from sqlalchemy import and_, func, or_, text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import selectinload
import sqlalchemy
from sqlalchemy import and_, or_
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import cast, column, literal, union
from sqlalchemy.types import Float
from matcher import database, model, wikidata, wikidata_api
from matcher.planet import line, point, polygon
@ -51,10 +48,11 @@ def get_country_iso3166_1(lat: float, lon: float) -> set[str]:
Normally there should be only one country.
"""
point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), srid)
point = sqlalchemy.func.ST_SetSRID(sqlalchemy.func.ST_MakePoint(lon, lat), srid)
alpha2_codes: set[str] = set()
q = model.Polygon.query.filter(
func.ST_Covers(model.Polygon.way, point), model.Polygon.admin_level == "2"
sqlalchemy.func.ST_Covers(model.Polygon.way, point),
model.Polygon.admin_level == "2",
)
for country in q:
alpha2: str = country.tags.get("ISO3166-1")
@ -90,13 +88,15 @@ def is_street_number_first(lat: float, lon: float) -> bool:
def make_envelope(bounds: list[float]) -> geoalchemy2.functions.ST_MakeEnvelope:
"""Make an envelope for the given bounds."""
return func.ST_MakeEnvelope(*bounds, srid)
return sqlalchemy.func.ST_MakeEnvelope(*bounds, srid)
def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]:
"""Get centroid of bounding box."""
bbox = make_envelope(bbox)
centroid = database.session.query(func.ST_AsText(func.ST_Centroid(bbox))).scalar()
centroid = database.session.query(
sqlalchemy.func.ST_AsText(sqlalchemy.func.ST_Centroid(bbox))
).scalar()
m = re_point.match(centroid)
assert m
lon, lat = m.groups()
@ -107,16 +107,25 @@ def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]:
def make_envelope_around_point(
lat: float, lon: float, distance: float
) -> geoalchemy2.functions.ST_MakeEnvelope:
"""Make an envelope around a point; the distance parameter specifies the size."""
conn = database.session.connection()
p = func.ST_MakePoint(lon, lat)
p = sqlalchemy.func.ST_MakePoint(lon, lat)
s = select(
[
func.ST_AsText(func.ST_Project(p, distance, func.radians(0))),
func.ST_AsText(func.ST_Project(p, distance, func.radians(90))),
func.ST_AsText(func.ST_Project(p, distance, func.radians(180))),
func.ST_AsText(func.ST_Project(p, distance, func.radians(270))),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(0))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(90))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(180))
),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(270))
),
]
)
row = conn.execute(s).fetchone()
@ -127,7 +136,7 @@ def make_envelope_around_point(
south = coords[2][1]
west = coords[3][0]
return func.ST_MakeEnvelope(west, south, east, north, srid)
return sqlalchemy.func.ST_MakeEnvelope(west, south, east, north, srid)
def drop_way_area(tags: TagsType) -> TagsType:
@ -146,13 +155,13 @@ def get_part_of(table_name, src_id, bbox):
[
polygon.c.osm_id,
polygon.c.tags,
func.ST_Area(func.ST_Collect(polygon.c.way)),
sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)),
]
)
.where(
and_(
func.ST_Intersects(bbox, polygon.c.way),
func.ST_Covers(polygon.c.way, table_alias.c.way),
sqlalchemy.func.ST_Intersects(bbox, polygon.c.way),
sqlalchemy.func.ST_Covers(polygon.c.way, table_alias.c.way),
table_alias.c.osm_id == src_id,
polygon.c.tags.has_key("name"),
or_(
@ -227,8 +236,8 @@ def get_items_in_bbox(bbox: list[float]):
q = (
model.Item.query.join(model.ItemLocation)
.filter(func.ST_Covers(db_bbox, model.ItemLocation.location))
.options(selectinload(model.Item.locations))
.filter(sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location))
.options(sqlalchemy.orm.selectinload(model.Item.locations))
)
return q
@ -239,7 +248,7 @@ def get_osm_with_wikidata_tag(bbox, isa_filter=None):
extra_sql = ""
if isa_filter:
q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(make_envelope(bbox), model.ItemLocation.location)
sqlalchemy.func.ST_Covers(make_envelope(bbox), model.ItemLocation.location)
)
q = add_isa_filter(q, isa_filter)
qids = [isa.qid for isa in q]
@ -274,7 +283,7 @@ WHERE tags ? 'wikidata'
+ extra_sql
)
conn = database.session.connection()
result = conn.execute(text(sql))
result = conn.execute(sqlalchemy.text(sql))
# print(sql)
@ -344,7 +353,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
osm_list = collections.defaultdict(list)
skip_isa = {row[0] for row in database.session.query(model.SkipIsA.item_id)}
skip_isa: set[int] = {
row[0] for row in database.session.query(model.SkipIsA.item_id)
}
tram_stop_id = 41176
airport_id = 1248784
@ -352,7 +363,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
if {tram_stop_id, airport_id, aerodrome_id} & set(isa_list):
skip_isa.add(41176) # building (Q41176)
seen = set(isa_list) | skip_isa
seen: set[int] = set(isa_list) | skip_isa
stop = {
"Q11799049": "public institution",
"Q7075": "library",
@ -364,7 +375,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
continue
isa_qid: str = typing.cast(str, isa.qid)
isa_path = isa_path + [{"qid": isa_qid, "label": isa.label()}]
osm = [v for v in isa.get_claim("P1282") if v not in skip_tags]
osm: list[str] = [
typing.cast(str, v) for v in isa.get_claim("P1282") if v not in skip_tags
]
osm += [
extra.tag_or_key
@ -378,7 +391,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
# item is specific enough, no need to keep walking the item hierarchy
continue
check = set()
check: set[int] = set()
properties = [
("P279", "subclass of"),
("P140", "religion"),
@ -389,11 +402,15 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]:
]
for pid, label in properties:
check |= {v["numeric-id"] for v in (isa.get_claim(pid) or []) if v}
check |= {
typing.cast(dict[str, int], v)["numeric-id"]
for v in (isa.get_claim(pid) or [])
if v
}
print(isa.qid, isa.label(), check)
isa_list = check - seen
seen.update(isa_list)
isa_list_set = check - seen
seen.update(isa_list_set)
isa_items += [(isa, isa_path) for isa in get_items(isa_list)]
return {key: list(values) for key, values in osm_list.items()}
@ -467,7 +484,7 @@ def get_tags_for_isa_item(item):
def add_isa_filter(q, isa_qids):
q_subclass = database.session.query(model.Item.qid).filter(
func.jsonb_path_query_array(
sqlalchemy.func.jsonb_path_query_array(
model.Item.claims,
"$.P279[*].mainsnak.datavalue.value.id",
).bool_op("?|")(list(isa_qids))
@ -475,7 +492,7 @@ def add_isa_filter(q, isa_qids):
subclass_qid = {qid for qid, in q_subclass.all()}
isa = func.jsonb_path_query_array(
isa = sqlalchemy.func.jsonb_path_query_array(
model.Item.claims,
"$.P31[*].mainsnak.datavalue.value.id",
).bool_op("?|")
@ -484,7 +501,7 @@ def add_isa_filter(q, isa_qids):
def wikidata_items_count(bounds, isa_filter=None):
q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(make_envelope(bounds), model.ItemLocation.location)
sqlalchemy.func.ST_Covers(make_envelope(bounds), model.ItemLocation.location)
)
if isa_filter:
@ -499,7 +516,7 @@ def wikidata_isa_counts(bounds, isa_filter=None):
db_bbox = make_envelope(bounds)
q = model.Item.query.join(model.ItemLocation).filter(
func.ST_Covers(db_bbox, model.ItemLocation.location)
sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location)
)
if isa_filter:
@ -529,8 +546,11 @@ def wikidata_isa_counts(bounds, isa_filter=None):
return isa_count
def get_tag_filter(tags, tag_list):
def get_tag_filter(
tags: sqlalchemy.sql.schema.Column, tag_list: list[str]
) -> list[sqlalchemy.sql.elements.BooleanClauseList]:
tag_filter = []
print("tags type:", type(tags))
for tag_or_key in tag_list:
if tag_or_key.startswith("Key:"):
key = tag_or_key[4:]
@ -544,10 +564,11 @@ def get_tag_filter(tags, tag_list):
for prefix in tag_prefixes:
tag_filter.append(tags[f"{prefix}:{k}"] == v)
print("tag_filter type:", [type(i) for i in tag_filter])
return tag_filter
def get_preset_translations():
def get_preset_translations() -> dict[str, typing.Any]:
app = flask.current_app
country_language = {
"AU": "en-AU", # Australia
@ -569,7 +590,9 @@ def get_preset_translations():
continue
try:
return json_data[lang_code]["presets"]["presets"]
return typing.cast(
dict[str, typing.Any], json_data[lang_code]["presets"]["presets"]
)
except KeyError:
pass
@ -665,8 +688,13 @@ def address_node_label(tags: TagsType) -> str | None:
def get_address_nodes_within_building(osm_id, bbox_list):
q = model.Point.query.filter(
polygon.c.osm_id == osm_id,
or_(*[func.ST_Intersects(bbox, model.Point.way) for bbox in bbox_list]),
func.ST_Covers(polygon.c.way, model.Point.way),
or_(
*[
sqlalchemy.func.ST_Intersects(bbox, model.Point.way)
for bbox in bbox_list
]
),
sqlalchemy.func.ST_Covers(polygon.c.way, model.Point.way),
model.Point.tags.has_key("addr:street"),
model.Point.tags.has_key("addr:housenumber"),
)
@ -708,9 +736,11 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
for loc in item.locations
]
null_area = cast(None, Float)
dist = column("dist")
tags = column("tags", postgresql.HSTORE)
null_area = sqlalchemy.sql.expression.cast(None, sqlalchemy.types.Float)
dist = sqlalchemy.sql.expression.column("dist")
tags = sqlalchemy.sql.expression.column(
"tags", sqlalchemy.dialects.postgresql.HSTORE
)
tag_list = get_item_tags(item)
# tag_filters = get_tag_filter(point.c.tags, tag_list)
@ -719,20 +749,27 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_point = (
select(
[
literal("point").label("t"),
sqlalchemy.sql.expression.literal("point").label("t"),
point.c.osm_id,
point.c.tags.label("tags"),
func.min(
func.ST_DistanceSphere(model.ItemLocation.location, point.c.way)
sqlalchemy.func.min(
sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, point.c.way
)
).label("dist"),
func.ST_AsText(point.c.way),
func.ST_AsGeoJSON(point.c.way),
sqlalchemy.func.ST_AsText(point.c.way),
sqlalchemy.func.ST_AsGeoJSON(point.c.way),
null_area,
]
)
.where(
and_(
or_(*[func.ST_Intersects(bbox, point.c.way) for bbox in bbox_list]),
or_(
*[
sqlalchemy.func.ST_Intersects(bbox, point.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(point.c.tags, tag_list)),
)
@ -743,20 +780,29 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_line = (
select(
[
literal("line").label("t"),
sqlalchemy.sql.expression.literal("line").label("t"),
line.c.osm_id,
line.c.tags.label("tags"),
func.min(
func.ST_DistanceSphere(model.ItemLocation.location, line.c.way)
sqlalchemy.func.min(
sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, line.c.way
)
).label("dist"),
func.ST_AsText(func.ST_Centroid(func.ST_Collect(line.c.way))),
func.ST_AsGeoJSON(func.ST_Collect(line.c.way)),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Centroid(sqlalchemy.func.ST_Collect(line.c.way))
),
sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(line.c.way)),
null_area,
]
)
.where(
and_(
or_(*[func.ST_Intersects(bbox, line.c.way) for bbox in bbox_list]),
or_(
*[
sqlalchemy.func.ST_Intersects(bbox, line.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(line.c.tags, tag_list)),
)
@ -767,33 +813,48 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None):
s_polygon = (
select(
[
literal("polygon").label("t"),
sqlalchemy.sql.expression.literal("polygon").label("t"),
polygon.c.osm_id,
polygon.c.tags.label("tags"),
func.min(
func.ST_DistanceSphere(model.ItemLocation.location, polygon.c.way)
sqlalchemy.func.min(
sqlalchemy.func.ST_DistanceSphere(
model.ItemLocation.location, polygon.c.way
)
).label("dist"),
func.ST_AsText(func.ST_Centroid(func.ST_Collect(polygon.c.way))),
func.ST_AsGeoJSON(func.ST_Collect(polygon.c.way)),
func.ST_Area(func.ST_Collect(polygon.c.way)),
sqlalchemy.func.ST_AsText(
sqlalchemy.func.ST_Centroid(
sqlalchemy.func.ST_Collect(polygon.c.way)
)
),
sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(polygon.c.way)),
sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)),
]
)
.where(
and_(
or_(*[func.ST_Intersects(bbox, polygon.c.way) for bbox in bbox_list]),
or_(
*[
sqlalchemy.func.ST_Intersects(bbox, polygon.c.way)
for bbox in bbox_list
]
),
model.ItemLocation.item_id == item_id,
or_(*get_tag_filter(polygon.c.tags, tag_list)),
)
)
.group_by(polygon.c.osm_id, polygon.c.tags)
.having(
func.ST_Area(func.ST_Collect(polygon.c.way))
< 20 * func.ST_Area(bbox_list[0])
sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way))
< 20 * sqlalchemy.func.ST_Area(bbox_list[0])
)
)
tables = ([] if item_is_linear_feature else [s_point]) + [s_line, s_polygon]
s = select([union(*tables).alias()]).where(dist < max_distance).order_by(dist)
s = (
select([sqlalchemy.sql.expression.union(*tables).alias()])
.where(dist < max_distance)
.order_by(dist)
)
if names:
s = s.where(or_(tags["name"].in_(names), tags["old_name"].in_(names)))
@ -1056,24 +1117,24 @@ def missing_wikidata_items(qids, lat, lon):
return dict(items=items, isa_count=isa_count)
def isa_incremental_search(search_terms):
en_label = func.jsonb_extract_path_text(model.Item.labels, "en", "value")
def isa_incremental_search(search_terms: str) -> list[dict[str, str]]:
"""Incremental search: match items that have a P1282 claim by English label."""
en_label = sqlalchemy.func.jsonb_extract_path_text(model.Item.labels, "en", "value")
q = model.Item.query.filter(
model.Item.claims.has_key("P1282"),
en_label.ilike(f"%{search_terms}%"),
func.length(en_label) < 20,
sqlalchemy.func.length(en_label) < 20,
)
# print(q.statement.compile(compile_kwargs={"literal_binds": True}))
ret = []
for item in q:
cur = {
return [
{
"qid": item.qid,
"label": item.label(),
}
ret.append(cur)
return ret
for item in q
]
class PlaceItems(typing.TypedDict):
@ -1091,7 +1152,7 @@ def get_place_items(osm_type: str, osm_id: int) -> PlaceItems:
model.Item.query.join(model.ItemLocation)
.join(
model.Polygon,
func.ST_Covers(model.Polygon.way, model.ItemLocation.location),
sqlalchemy.func.ST_Covers(model.Polygon.way, model.ItemLocation.location),
)
.filter(model.Polygon.src_id == src_id)
)