Add types and docstrings
This commit is contained in:
parent
fc0d6f114a
commit
2c267c67e2
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from typing import Any
|
||||
import typing
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
@ -8,6 +8,8 @@ from simplejson.scanner import JSONDecodeError
|
|||
from .language import get_current_language
|
||||
from .util import is_disambig
|
||||
|
||||
StrDict = dict[str, typing.Any]
|
||||
|
||||
ua = (
|
||||
"find-link/2.2 "
|
||||
+ "(https://github.com/EdwardBetts/find_link; contact: edward@4angle.com)"
|
||||
|
@ -20,42 +22,47 @@ def get_query_url() -> str:
|
|||
return f"https://{get_current_language()}.wikipedia.org/w/api.php"
|
||||
|
||||
|
||||
sessions = {}
|
||||
sessions: dict[str, requests.sessions.Session] = {}
|
||||
|
||||
|
||||
def get_session():
|
||||
def get_session() -> requests.sessions.Session:
|
||||
"""Get requests session."""
|
||||
lang = get_current_language()
|
||||
if lang in sessions:
|
||||
return sessions[lang]
|
||||
s = requests.Session()
|
||||
s.headers = {"User-Agent": ua}
|
||||
s.mount("https://en.wikipedia.org", HTTPAdapter(max_retries=10))
|
||||
s.params = {
|
||||
"format": "json",
|
||||
"action": "query",
|
||||
"formatversion": 2,
|
||||
}
|
||||
s.params = typing.cast(
|
||||
dict[str, str | int],
|
||||
{
|
||||
"format": "json",
|
||||
"action": "query",
|
||||
"formatversion": 2,
|
||||
},
|
||||
)
|
||||
sessions[lang] = s
|
||||
return s
|
||||
|
||||
|
||||
class MediawikiError(Exception):
|
||||
pass
|
||||
"""Mediawiki error."""
|
||||
|
||||
|
||||
class MultipleRedirects(Exception):
|
||||
pass
|
||||
"""Multiple redirects."""
|
||||
|
||||
|
||||
class IncompleteReply(Exception):
|
||||
pass
|
||||
"""Incomplete reply."""
|
||||
|
||||
|
||||
class MissingPage(Exception):
|
||||
pass
|
||||
"""Missing page."""
|
||||
|
||||
|
||||
def check_for_error(json_data):
|
||||
def check_for_error(json_data: dict[str, typing.Any]) -> None:
|
||||
"""Check MediaWiki API reply for error."""
|
||||
if "error" in json_data:
|
||||
raise MediawikiError(json_data["error"]["info"])
|
||||
|
||||
|
@ -65,13 +72,13 @@ webpage_error = (
|
|||
)
|
||||
|
||||
|
||||
def api_get(params: dict[str, Any]) -> dict[str, Any]:
|
||||
def api_get(params: StrDict) -> StrDict:
|
||||
"""Make call to Wikipedia API."""
|
||||
s = get_session()
|
||||
|
||||
r = s.get(get_query_url(), params=params)
|
||||
try:
|
||||
ret = r.json()
|
||||
ret: StrDict = r.json()
|
||||
except JSONDecodeError:
|
||||
if webpage_error in r.text:
|
||||
raise MediawikiError(webpage_error)
|
||||
|
@ -81,22 +88,23 @@ def api_get(params: dict[str, Any]) -> dict[str, Any]:
|
|||
return ret
|
||||
|
||||
|
||||
def get_first_page(params: dict[str, str]) -> dict[str, Any]:
|
||||
def get_first_page(params: dict[str, str]) -> StrDict:
|
||||
"""Run Wikipedia API query and return the first page."""
|
||||
page = api_get(params)["query"]["pages"][0]
|
||||
page: StrDict = api_get(params)["query"]["pages"][0]
|
||||
if page.get("missing"):
|
||||
raise MissingPage
|
||||
return page
|
||||
|
||||
|
||||
def random_article_list(limit=50):
|
||||
def random_article_list(limit: int = 50) -> list[StrDict]:
|
||||
"""Get random sample of articles."""
|
||||
params = {
|
||||
"list": "random",
|
||||
"rnnamespace": "0",
|
||||
"rnlimit": limit,
|
||||
}
|
||||
|
||||
return api_get(params)["query"]["random"]
|
||||
return typing.cast(list[StrDict], api_get(params)["query"]["random"])
|
||||
|
||||
|
||||
def wiki_search(q):
|
||||
|
@ -185,23 +193,6 @@ def categorymembers(q: str) -> list[str]:
|
|||
return [i["title"] for i in ret["categorymembers"] if i["title"] != q]
|
||||
|
||||
|
||||
def page_links(titles): # unused
|
||||
titles = list(titles)
|
||||
assert titles
|
||||
params = {
|
||||
"prop": "links",
|
||||
"pllimit": 500,
|
||||
"plnamespace": 0,
|
||||
"titles": "|".join(titles),
|
||||
}
|
||||
ret = api_get(params)["query"]
|
||||
return dict(
|
||||
(doc["title"], {l["title"] for l in doc["links"]})
|
||||
for doc in ret["pages"].values()
|
||||
if "links" in doc
|
||||
)
|
||||
|
||||
|
||||
def find_disambig(titles: list[str]) -> list[str]:
|
||||
"""Find disambiguation articles in the given list of titles."""
|
||||
titles = list(titles)
|
||||
|
@ -235,7 +226,8 @@ def find_disambig(titles: list[str]) -> list[str]:
|
|||
return disambig
|
||||
|
||||
|
||||
def wiki_redirects(q): # pages that link here
|
||||
def wiki_redirects(q: str) -> typing.Iterator[str]:
|
||||
"""Pages that link here."""
|
||||
params = {
|
||||
"list": "backlinks",
|
||||
"blfilterredir": "redirects",
|
||||
|
@ -269,7 +261,8 @@ def wiki_backlink(q: str) -> tuple[set[str], set[str]]:
|
|||
return (articles, redirects)
|
||||
|
||||
|
||||
def call_get_diff(title, section_num, section_text):
|
||||
def call_get_diff(title: str, section_num: int, section_text: str) -> str:
|
||||
"""Get diff from Wikipedia."""
|
||||
data = {
|
||||
"prop": "revisions",
|
||||
"rvprop": "timestamp",
|
||||
|
@ -281,4 +274,4 @@ def call_get_diff(title, section_num, section_text):
|
|||
s = get_session()
|
||||
ret = s.post(get_query_url(), data=data).json()
|
||||
check_for_error(ret)
|
||||
return ret["query"]["pages"][0]["revisions"][0]["diff"]["body"]
|
||||
return typing.cast(str, ret["query"]["pages"][0]["revisions"][0]["diff"]["body"])
|
||||
|
|
Loading…
Reference in a new issue