From b8ed296f78e662170cdcd728c791046d1d89f511 Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Wed, 17 May 2023 16:28:44 +0000 Subject: [PATCH] Type hints and docstrings. --- matcher/api.py | 205 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 133 insertions(+), 72 deletions(-) diff --git a/matcher/api.py b/matcher/api.py index fb1e919..305015e 100644 --- a/matcher/api.py +++ b/matcher/api.py @@ -6,12 +6,9 @@ import typing import flask import geoalchemy2 -from sqlalchemy import and_, func, or_, text -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import selectinload +import sqlalchemy +from sqlalchemy import and_, or_ from sqlalchemy.sql import select -from sqlalchemy.sql.expression import cast, column, literal, union -from sqlalchemy.types import Float from matcher import database, model, wikidata, wikidata_api from matcher.planet import line, point, polygon @@ -51,10 +48,11 @@ def get_country_iso3166_1(lat: float, lon: float) -> set[str]: Normally there should be only one country. """ - point = func.ST_SetSRID(func.ST_MakePoint(lon, lat), srid) + point = sqlalchemy.func.ST_SetSRID(sqlalchemy.func.ST_MakePoint(lon, lat), srid) alpha2_codes: set[str] = set() q = model.Polygon.query.filter( - func.ST_Covers(model.Polygon.way, point), model.Polygon.admin_level == "2" + sqlalchemy.func.ST_Covers(model.Polygon.way, point), + model.Polygon.admin_level == "2", ) for country in q: alpha2: str = country.tags.get("ISO3166-1") @@ -90,13 +88,15 @@ def is_street_number_first(lat: float, lon: float) -> bool: def make_envelope(bounds: list[float]) -> geoalchemy2.functions.ST_MakeEnvelope: """Make en envelope for the given bounds.""" - return func.ST_MakeEnvelope(*bounds, srid) + return sqlalchemy.func.ST_MakeEnvelope(*bounds, srid) def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]: """Get centroid of bounding box.""" bbox = make_envelope(bbox) - centroid = database.session.query(func.ST_AsText(func.ST_Centroid(bbox))).scalar() + centroid = database.session.query( + sqlalchemy.func.ST_AsText(sqlalchemy.func.ST_Centroid(bbox)) + ).scalar() m = re_point.match(centroid) assert m lon, lat = m.groups() @@ -107,16 +107,25 @@ def get_bbox_centroid(bbox: list[float]) -> tuple[str, str]: def make_envelope_around_point( lat: float, lon: float, distance: float ) -> geoalchemy2.functions.ST_MakeEnvelope: + """Make an envelope around a point, the distance parameter specifies the size.""" conn = database.session.connection() - p = func.ST_MakePoint(lon, lat) + p = sqlalchemy.func.ST_MakePoint(lon, lat) s = select( [ - func.ST_AsText(func.ST_Project(p, distance, func.radians(0))), - func.ST_AsText(func.ST_Project(p, distance, func.radians(90))), - func.ST_AsText(func.ST_Project(p, distance, func.radians(180))), - func.ST_AsText(func.ST_Project(p, distance, func.radians(270))), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(0)) + ), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(90)) + ), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(180)) + ), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Project(p, distance, sqlalchemy.func.radians(270)) + ), ] ) row = conn.execute(s).fetchone() @@ -127,7 +136,7 @@ def make_envelope_around_point( south = coords[2][1] west = coords[3][0] - return func.ST_MakeEnvelope(west, south, east, north, srid) + return sqlalchemy.func.ST_MakeEnvelope(west, south, east, north, srid) def drop_way_area(tags: TagsType) -> TagsType: @@ -146,13 +155,13 @@ def get_part_of(table_name, src_id, bbox): [ polygon.c.osm_id, polygon.c.tags, - func.ST_Area(func.ST_Collect(polygon.c.way)), + sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)), ] ) .where( and_( - func.ST_Intersects(bbox, polygon.c.way), - func.ST_Covers(polygon.c.way, table_alias.c.way), + sqlalchemy.func.ST_Intersects(bbox, polygon.c.way), + sqlalchemy.func.ST_Covers(polygon.c.way, table_alias.c.way), table_alias.c.osm_id == src_id, polygon.c.tags.has_key("name"), or_( @@ -227,8 +236,8 @@ def get_items_in_bbox(bbox: list[float]): q = ( model.Item.query.join(model.ItemLocation) - .filter(func.ST_Covers(db_bbox, model.ItemLocation.location)) - .options(selectinload(model.Item.locations)) + .filter(sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location)) + .options(sqlalchemy.orm.selectinload(model.Item.locations)) ) return q @@ -239,7 +248,7 @@ def get_osm_with_wikidata_tag(bbox, isa_filter=None): extra_sql = "" if isa_filter: q = model.Item.query.join(model.ItemLocation).filter( - func.ST_Covers(make_envelope(bbox), model.ItemLocation.location) + sqlalchemy.func.ST_Covers(make_envelope(bbox), model.ItemLocation.location) ) q = add_isa_filter(q, isa_filter) qids = [isa.qid for isa in q] @@ -274,7 +283,7 @@ WHERE tags ? 'wikidata' + extra_sql ) conn = database.session.connection() - result = conn.execute(text(sql)) + result = conn.execute(sqlalchemy.text(sql)) # print(sql) @@ -344,7 +353,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]: osm_list = collections.defaultdict(list) - skip_isa = {row[0] for row in database.session.query(model.SkipIsA.item_id)} + skip_isa: set[int] = { + row[0] for row in database.session.query(model.SkipIsA.item_id) + } tram_stop_id = 41176 airport_id = 1248784 @@ -352,7 +363,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]: if {tram_stop_id, airport_id, aerodrome_id} & set(isa_list): skip_isa.add(41176) # building (Q41176) - seen = set(isa_list) | skip_isa + seen: set[int] = set(isa_list) | skip_isa stop = { "Q11799049": "public institution", "Q7075": "library", @@ -364,7 +375,9 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]: continue isa_qid: str = typing.cast(str, isa.qid) isa_path = isa_path + [{"qid": isa_qid, "label": isa.label()}] - osm = [v for v in isa.get_claim("P1282") if v not in skip_tags] + osm: list[str] = [ + typing.cast(str, v) for v in isa.get_claim("P1282") if v not in skip_tags + ] osm += [ extra.tag_or_key @@ -378,7 +391,7 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]: # item is specific enough, no need to keep walking the item hierarchy continue - check = set() + check: set[int] = set() properties = [ ("P279", "subclass of"), ("P140", "religion"), @@ -389,11 +402,15 @@ def get_item_tags(item: model.Item) -> dict[str, list[str]]: ] for pid, label in properties: - check |= {v["numeric-id"] for v in (isa.get_claim(pid) or []) if v} + check |= { + typing.cast(dict[str, int], v)["numeric-id"] + for v in (isa.get_claim(pid) or []) + if v + } print(isa.qid, isa.label(), check) - isa_list = check - seen - seen.update(isa_list) + isa_list_set = check - seen + seen.update(isa_list_set) isa_items += [(isa, isa_path) for isa in get_items(isa_list)] return {key: list(values) for key, values in osm_list.items()} @@ -467,7 +484,7 @@ def get_tags_for_isa_item(item): def add_isa_filter(q, isa_qids): q_subclass = database.session.query(model.Item.qid).filter( - func.jsonb_path_query_array( + sqlalchemy.func.jsonb_path_query_array( model.Item.claims, "$.P279[*].mainsnak.datavalue.value.id", ).bool_op("?|")(list(isa_qids)) @@ -475,7 +492,7 @@ def add_isa_filter(q, isa_qids): subclass_qid = {qid for qid, in q_subclass.all()} - isa = func.jsonb_path_query_array( + isa = sqlalchemy.func.jsonb_path_query_array( model.Item.claims, "$.P31[*].mainsnak.datavalue.value.id", ).bool_op("?|") @@ -484,7 +501,7 @@ def add_isa_filter(q, isa_qids): def wikidata_items_count(bounds, isa_filter=None): q = model.Item.query.join(model.ItemLocation).filter( - func.ST_Covers(make_envelope(bounds), model.ItemLocation.location) + sqlalchemy.func.ST_Covers(make_envelope(bounds), model.ItemLocation.location) ) if isa_filter: @@ -499,7 +516,7 @@ def wikidata_isa_counts(bounds, isa_filter=None): db_bbox = make_envelope(bounds) q = model.Item.query.join(model.ItemLocation).filter( - func.ST_Covers(db_bbox, model.ItemLocation.location) + sqlalchemy.func.ST_Covers(db_bbox, model.ItemLocation.location) ) if isa_filter: @@ -529,8 +546,11 @@ def wikidata_isa_counts(bounds, isa_filter=None): return isa_count -def get_tag_filter(tags, tag_list): +def get_tag_filter( + tags: sqlalchemy.sql.schema.Column, tag_list: list[str] +) -> list[sqlalchemy.sql.elements.BooleanClauseList]: tag_filter = [] + print("tags type:", type(tags)) for tag_or_key in tag_list: if tag_or_key.startswith("Key:"): key = tag_or_key[4:] @@ -544,10 +564,11 @@ def get_tag_filter(tags, tag_list): for prefix in tag_prefixes: tag_filter.append(tags[f"{prefix}:{k}"] == v) + print("tag_filter type:", [type(i) for i in tag_filter]) return tag_filter -def get_preset_translations(): +def get_preset_translations() -> dict[str, typing.Any]: app = flask.current_app country_language = { "AU": "en-AU", # Australia @@ -569,7 +590,9 @@ def get_preset_translations(): continue try: - return json_data[lang_code]["presets"]["presets"] + return typing.cast( + dict[str, typing.Any], json_data[lang_code]["presets"]["presets"] + ) except KeyError: pass @@ -665,8 +688,13 @@ def address_node_label(tags: TagsType) -> str | None: def get_address_nodes_within_building(osm_id, bbox_list): q = model.Point.query.filter( polygon.c.osm_id == osm_id, - or_(*[func.ST_Intersects(bbox, model.Point.way) for bbox in bbox_list]), - func.ST_Covers(polygon.c.way, model.Point.way), + or_( + *[ + sqlalchemy.func.ST_Intersects(bbox, model.Point.way) + for bbox in bbox_list + ] + ), + sqlalchemy.func.ST_Covers(polygon.c.way, model.Point.way), model.Point.tags.has_key("addr:street"), model.Point.tags.has_key("addr:housenumber"), ) @@ -708,9 +736,11 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None): for loc in item.locations ] - null_area = cast(None, Float) - dist = column("dist") - tags = column("tags", postgresql.HSTORE) + null_area = sqlalchemy.sql.expression.cast(None, sqlalchemy.types.Float) + dist = sqlalchemy.sql.expression.column("dist") + tags = sqlalchemy.sql.expression.column( + "tags", sqlalchemy.dialects.postgresql.HSTORE + ) tag_list = get_item_tags(item) # tag_filters = get_tag_filter(point.c.tags, tag_list) @@ -719,20 +749,27 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None): s_point = ( select( [ - literal("point").label("t"), + sqlalchemy.sql.expression.literal("point").label("t"), point.c.osm_id, point.c.tags.label("tags"), - func.min( - func.ST_DistanceSphere(model.ItemLocation.location, point.c.way) + sqlalchemy.func.min( + sqlalchemy.func.ST_DistanceSphere( + model.ItemLocation.location, point.c.way + ) ).label("dist"), - func.ST_AsText(point.c.way), - func.ST_AsGeoJSON(point.c.way), + sqlalchemy.func.ST_AsText(point.c.way), + sqlalchemy.func.ST_AsGeoJSON(point.c.way), null_area, ] ) .where( and_( - or_(*[func.ST_Intersects(bbox, point.c.way) for bbox in bbox_list]), + or_( + *[ + sqlalchemy.func.ST_Intersects(bbox, point.c.way) + for bbox in bbox_list + ] + ), model.ItemLocation.item_id == item_id, or_(*get_tag_filter(point.c.tags, tag_list)), ) @@ -743,20 +780,29 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None): s_line = ( select( [ - literal("line").label("t"), + sqlalchemy.sql.expression.literal("line").label("t"), line.c.osm_id, line.c.tags.label("tags"), - func.min( - func.ST_DistanceSphere(model.ItemLocation.location, line.c.way) + sqlalchemy.func.min( + sqlalchemy.func.ST_DistanceSphere( + model.ItemLocation.location, line.c.way + ) ).label("dist"), - func.ST_AsText(func.ST_Centroid(func.ST_Collect(line.c.way))), - func.ST_AsGeoJSON(func.ST_Collect(line.c.way)), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Centroid(sqlalchemy.func.ST_Collect(line.c.way)) + ), + sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(line.c.way)), null_area, ] ) .where( and_( - or_(*[func.ST_Intersects(bbox, line.c.way) for bbox in bbox_list]), + or_( + *[ + sqlalchemy.func.ST_Intersects(bbox, line.c.way) + for bbox in bbox_list + ] + ), model.ItemLocation.item_id == item_id, or_(*get_tag_filter(line.c.tags, tag_list)), ) @@ -767,33 +813,48 @@ def find_osm_candidates(item, limit=80, max_distance=450, names=None): s_polygon = ( select( [ - literal("polygon").label("t"), + sqlalchemy.sql.expression.literal("polygon").label("t"), polygon.c.osm_id, polygon.c.tags.label("tags"), - func.min( - func.ST_DistanceSphere(model.ItemLocation.location, polygon.c.way) + sqlalchemy.func.min( + sqlalchemy.func.ST_DistanceSphere( + model.ItemLocation.location, polygon.c.way + ) ).label("dist"), - func.ST_AsText(func.ST_Centroid(func.ST_Collect(polygon.c.way))), - func.ST_AsGeoJSON(func.ST_Collect(polygon.c.way)), - func.ST_Area(func.ST_Collect(polygon.c.way)), + sqlalchemy.func.ST_AsText( + sqlalchemy.func.ST_Centroid( + sqlalchemy.func.ST_Collect(polygon.c.way) + ) + ), + sqlalchemy.func.ST_AsGeoJSON(sqlalchemy.func.ST_Collect(polygon.c.way)), + sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)), ] ) .where( and_( - or_(*[func.ST_Intersects(bbox, polygon.c.way) for bbox in bbox_list]), + or_( + *[ + sqlalchemy.func.ST_Intersects(bbox, polygon.c.way) + for bbox in bbox_list + ] + ), model.ItemLocation.item_id == item_id, or_(*get_tag_filter(polygon.c.tags, tag_list)), ) ) .group_by(polygon.c.osm_id, polygon.c.tags) .having( - func.ST_Area(func.ST_Collect(polygon.c.way)) - < 20 * func.ST_Area(bbox_list[0]) + sqlalchemy.func.ST_Area(sqlalchemy.func.ST_Collect(polygon.c.way)) + < 20 * sqlalchemy.func.ST_Area(bbox_list[0]) ) ) tables = ([] if item_is_linear_feature else [s_point]) + [s_line, s_polygon] - s = select([union(*tables).alias()]).where(dist < max_distance).order_by(dist) + s = ( + select([sqlalchemy.sql.expression.union(*tables).alias()]) + .where(dist < max_distance) + .order_by(dist) + ) if names: s = s.where(or_(tags["name"].in_(names), tags["old_name"].in_(names))) @@ -1056,24 +1117,24 @@ def missing_wikidata_items(qids, lat, lon): return dict(items=items, isa_count=isa_count) -def isa_incremental_search(search_terms): - en_label = func.jsonb_extract_path_text(model.Item.labels, "en", "value") +def isa_incremental_search(search_terms: str) -> list[dict[str, str]]: + """Incremental search.""" + en_label = sqlalchemy.func.jsonb_extract_path_text(model.Item.labels, "en", "value") q = model.Item.query.filter( model.Item.claims.has_key("P1282"), en_label.ilike(f"%{search_terms}%"), - func.length(en_label) < 20, + sqlalchemy.func.length(en_label) < 20, ) # print(q.statement.compile(compile_kwargs={"literal_binds": True})) - ret = [] - for item in q: - cur = { + return [ + { "qid": item.qid, "label": item.label(), } - ret.append(cur) - return ret + for item in q + ] class PlaceItems(typing.TypedDict): @@ -1091,7 +1152,7 @@ def get_place_items(osm_type: str, osm_id: int) -> PlaceItems: model.Item.query.join(model.ItemLocation) .join( model.Polygon, - func.ST_Covers(model.Polygon.way, model.ItemLocation.location), + sqlalchemy.func.ST_Covers(model.Polygon.way, model.ItemLocation.location), ) .filter(model.Polygon.src_id == src_id) )