Add types and docstrings

This commit is contained in:
Edward Betts 2023-12-06 11:30:34 +00:00
parent fc0d6f114a
commit 2c267c67e2

View file

@ -1,5 +1,5 @@
import re
from typing import Any
import typing
import requests
from requests.adapters import HTTPAdapter
@ -8,6 +8,8 @@ from simplejson.scanner import JSONDecodeError
from .language import get_current_language
from .util import is_disambig
StrDict = dict[str, typing.Any]
ua = (
"find-link/2.2 "
+ "(https://github.com/EdwardBetts/find_link; contact: edward@4angle.com)"
@ -20,42 +22,47 @@ def get_query_url() -> str:
return f"https://{get_current_language()}.wikipedia.org/w/api.php"
sessions = {}
sessions: dict[str, requests.sessions.Session] = {}
def get_session():
def get_session() -> requests.sessions.Session:
"""Get requests session."""
lang = get_current_language()
if lang in sessions:
return sessions[lang]
s = requests.Session()
s.headers = {"User-Agent": ua}
s.mount("https://en.wikipedia.org", HTTPAdapter(max_retries=10))
s.params = {
s.params = typing.cast(
dict[str, str | int],
{
"format": "json",
"action": "query",
"formatversion": 2,
}
},
)
sessions[lang] = s
return s
class MediawikiError(Exception):
pass
"""Mediawiki error."""
class MultipleRedirects(Exception):
pass
"""Multiple redirects."""
class IncompleteReply(Exception):
pass
"""Incomplete reply."""
class MissingPage(Exception):
pass
"""Missing page."""
def check_for_error(json_data):
def check_for_error(json_data: dict[str, typing.Any]) -> None:
"""Check MediaWiki API reply for error."""
if "error" in json_data:
raise MediawikiError(json_data["error"]["info"])
@ -65,13 +72,13 @@ webpage_error = (
)
def api_get(params: dict[str, Any]) -> dict[str, Any]:
def api_get(params: StrDict) -> StrDict:
"""Make call to Wikipedia API."""
s = get_session()
r = s.get(get_query_url(), params=params)
try:
ret = r.json()
ret: StrDict = r.json()
except JSONDecodeError:
if webpage_error in r.text:
raise MediawikiError(webpage_error)
@ -81,22 +88,23 @@ def api_get(params: dict[str, Any]) -> dict[str, Any]:
return ret
def get_first_page(params: dict[str, str]) -> dict[str, Any]:
def get_first_page(params: dict[str, str]) -> StrDict:
"""Run Wikipedia API query and return the first page."""
page = api_get(params)["query"]["pages"][0]
page: StrDict = api_get(params)["query"]["pages"][0]
if page.get("missing"):
raise MissingPage
return page
def random_article_list(limit=50):
def random_article_list(limit: int = 50) -> list[StrDict]:
"""Get random sample of articles."""
params = {
"list": "random",
"rnnamespace": "0",
"rnlimit": limit,
}
return api_get(params)["query"]["random"]
return typing.cast(list[StrDict], api_get(params)["query"]["random"])
def wiki_search(q):
@ -185,23 +193,6 @@ def categorymembers(q: str) -> list[str]:
return [i["title"] for i in ret["categorymembers"] if i["title"] != q]
def page_links(titles): # unused
titles = list(titles)
assert titles
params = {
"prop": "links",
"pllimit": 500,
"plnamespace": 0,
"titles": "|".join(titles),
}
ret = api_get(params)["query"]
return dict(
(doc["title"], {l["title"] for l in doc["links"]})
for doc in ret["pages"].values()
if "links" in doc
)
def find_disambig(titles: list[str]) -> list[str]:
"""Find disambiguation articles in the given list of titles."""
titles = list(titles)
@ -235,7 +226,8 @@ def find_disambig(titles: list[str]) -> list[str]:
return disambig
def wiki_redirects(q): # pages that link here
def wiki_redirects(q: str) -> typing.Iterator[str]:
"""Pages that link here."""
params = {
"list": "backlinks",
"blfilterredir": "redirects",
@ -269,7 +261,8 @@ def wiki_backlink(q: str) -> tuple[set[str], set[str]]:
return (articles, redirects)
def call_get_diff(title, section_num, section_text):
def call_get_diff(title: str, section_num: int, section_text: str) -> str:
"""Get diff from Wikipedia."""
data = {
"prop": "revisions",
"rvprop": "timestamp",
@ -281,4 +274,4 @@ def call_get_diff(title, section_num, section_text):
s = get_session()
ret = s.post(get_query_url(), data=data).json()
check_for_error(ret)
return ret["query"]["pages"][0]["revisions"][0]["diff"]["body"]
return typing.cast(str, ret["query"]["pages"][0]["revisions"][0]["diff"]["body"])