"""Models.""" from __future__ import annotations import re import typing from flask import current_app, url_for from flask_login import UserMixin from hashids import Hashids from sqlalchemy import Column, ForeignKey, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import configure_mappers, relationship, synonym, validates from sqlalchemy.sql import exists from sqlalchemy.types import ( Boolean, DateTime, Enum, Integer, String, Unicode, UnicodeText, ) from sqlalchemy_continuum import make_versioned from sqlalchemy_continuum.plugins import ActivityPlugin, FlaskPlugin from werkzeug.security import check_password_hash, generate_password_hash from .database import session from .parse import parse_link, parse_span from .text import first_non_empty_line activity_plugin = ActivityPlugin() make_versioned(plugins=[FlaskPlugin(), activity_plugin]) doc_hashids = Hashids(min_length=8) Base = declarative_base() Base.query = session.query_property() re_server_url = re.compile(r"^http://perma.pub/\d+/([^/]+)/([^/]+)$") # list of disallowed usernames - maybe this should be in the database reserved_name = [ "root", "admin", "administrator", "support", "info", "test", "tech", "online", "old", "new", "jobs", "login", "job", "ipad" "iphone", "javascript", "script", "host", "mail", "image", "faq", "file", "ftp", "error", "warning", "the", "assistance", "maintenance", "controller", "head", "chief", "anon", ] re_username = re.compile(r"^\w+$", re.U) re_full_name = re.compile(r'^([-.\'" ]|[^\W\d_])+$', re.U) re_comment = re.compile(r"#.*") def item_url(): return url_for("view.view_item", username=self.user.username, hashid=self.hashid) def user_exists(field, value): return session.query(exists().where(field == value)).scalar() class TimeStampedModel(Base): __abstract__ = True created = Column(DateTime, default=func.now()) modified = Column(DateTime, default=func.now(), onupdate=func.now()) class 
LoginError(Exception): def __init__(self, msg): self.msg = msg class User(TimeStampedModel, UserMixin): __tablename__ = "user" id = Column(Integer, primary_key=True) username = Column(Unicode(32), unique=True, nullable=False) pw_hash = Column(String(160), nullable=False) email = Column(Unicode(64), unique=True, nullable=False) email_verified = Column(Boolean(), nullable=False, default=False) disabled = Column(Boolean(), nullable=False, default=False) deleted = Column(Boolean(), nullable=False, default=False) is_super = Column(Boolean, nullable=False, default=False) last_login = Column(DateTime) full_name = Column(Unicode(64)) balance = Column(Integer, nullable=False, default=0) user_id = synonym("id") name = synonym("full_name") user_name = synonym("username") def __init__(self, **kwargs): pw_hash = generate_password_hash(kwargs.pop("password")) return super(User, self).__init__(pw_hash=pw_hash, **kwargs) def __repr__(self): return "<User: {!r}>".format(self.username) def set_password(self, password): self.pw_hash = generate_password_hash(password) def check_password(self, password): return check_password_hash(self.pw_hash, password) def get_id(self): return self.id @validates("email") def validate_email(self, key, value): assert "@" in value return value @validates("username") def validate_usernane(self, key, value): assert re_username.match(value) return value @validates("full_name") def validate_full_name(self, key, value): if value: assert re_full_name.match(value) return value @hybrid_property def is_live(self): return self.email_verified & ~self.disabled & ~self.deleted @classmethod def lookup_user_or_email(cls, user_or_email): field = cls.email if "@" in user_or_email else cls.username return cls.query.filter(field == user_or_email).one_or_none() @property def mail_to_name(self): """Name to use on e-mails sent to the user.""" return self.full_name or self.username @classmethod def attempt_login(cls, user_or_email, password): user = 
cls.lookup_user_or_email(user_or_email) if not user: raise LoginError("user not found") if user.disabled: raise LoginError("user account disabled") if not user.check_password(password): raise LoginError("incorrect password") return user class Reference(Base): __tablename__ = "reference" subject_id = Column(Integer, ForeignKey("item.id"), primary_key=True) object_id = Column(Integer, ForeignKey("item.id"), primary_key=True) class Item(TimeStampedModel): __tablename__ = "item" __versioned__ = {"base_classes": (TimeStampedModel,)} id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("user.id")) published = Column(DateTime) type = Column( Enum("sourcedoc", "xanapage", "xanalink", name="item_type"), nullable=False ) filename = Column(Unicode) text = Column(UnicodeText) subjects = relationship( "Item", lazy="dynamic", secondary="reference", primaryjoin=id == Reference.object_id, secondaryjoin=id == Reference.subject_id, ) objects = relationship( "Item", lazy="dynamic", secondary="reference", primaryjoin=id == Reference.subject_id, secondaryjoin=id == Reference.object_id, ) user = relationship("User", backref="items") __mapper_args__ = { "polymorphic_on": type, "with_polymorphic": "*", } @property def hashid(self) -> str: """Hashid for item.""" return doc_hashids.encode(self.id) @classmethod def get_by_hashid(cls, hashid: str) -> Item | None: """Return the item with the given hashid.""" try: item_id = doc_hashids.decode(hashid)[0] except IndexError: return None return typing.cast("Item", cls.query.get(item_id)) def view_url(self, endpoint, **kwargs) -> str: return url_for( "view." 
+ endpoint, username=self.user.username, hashid=self.hashid, **kwargs, ) @property def url(self) -> str: return self.view_url("view_item") def url_fragment(self): return self.user.username + "/" + self.hashid def version_url(self, version): return self.view_url("view_item", v=version) @property def history_url(self): return self.view_url("history") @property def external_url(self): base_url = current_app.config.get("BASE_URL") if not base_url.endswith("/"): base_url += "/" if base_url: return base_url + self.url_fragment() else: return self.view_url("view_item", _external=True) @property def edit_url(self): return self.view_url("edit_item") @property def set_title_url(self): return self.view_url("set_title") def title_from_link(self, titles=None): if not titles: titles = XanaLink.get_all_titles() return titles.get(self) def title(self, titles=None): return self.type + ": " + (self.title_from_link(titles) or self.hashid) def has_title(self) -> bool: """Item has a title.""" titles = XanaLink.get_all_titles() return self in titles def set_title(self, title, user): title_source_doc = SourceDoc(text=title, user=user) session.add(title_source_doc) session.commit() link_text = """type=title facet= sourcedoc: {} facet= span: {},start=0,length={}""".format( self.external_url, title_source_doc.external_url, len(title) ) title_link = XanaLink(text=link_text, user=user) session.add(title_link) session.commit() @classmethod def from_external(cls, url: str, home: str | None = None) -> None | "Item": """Get item from URL.""" base = current_app.config.get("BASE_URL") username, hashid = None, None if home is None: home = url_for("view.home", _external=True) if url.startswith(home): username, _, hashid = url[len(home) :].partition("/") elif base and url.startswith(base): username, _, hashid = url[len(base) :].lstrip("/").partition("/") if username and "/" in username or hashid and "/" in hashid: username, hashid = None, None if not username or not hashid: m = 
re_server_url.match(url) if not m: return None username, hashid = m.groups() item_id = doc_hashids.decode(hashid)[0] q = cls.query.filter(User.username == username, cls.id == item_id) return q.one_or_none() class XanaPage(Item): __tablename__ = "xanapage" __mapper_args__ = {"polymorphic_identity": "xanapage"} id = Column(Integer, ForeignKey(Item.id), primary_key=True) def snippet(self): return self.text @property def xanaedit_url(self): return self.view_url("xanaedit_item") @property def save_xanaedit_url(self): return self.view_url("save_xanaedit") def iter_spans(self): for line in self.text.splitlines(): line = re_comment.sub("", line).strip() if not line: continue span_pointer = parse_span(line) if span_pointer: yield span_pointer def update_references(self) -> None: """Update references.""" for url, start, length in self.iter_spans(): src_doc = Item.from_external(url) if not src_doc or not src_doc.id: continue existing = Reference.query.get((self.id, src_doc.id)) if existing: continue ref = Reference(subject_id=self.id, object_id=src_doc.id) session.add(ref) session.commit() class XanaLink(Item): __tablename__ = "xanalink" __mapper_args__ = {"polymorphic_identity": "xanalink"} id = Column(Integer, ForeignKey(Item.id), primary_key=True) def parse(self): return parse_link(self.text) @property def link_type(self): return self.parse()["type"] def title(self, titles=None): if titles is None: titles = XanaLink.get_all_titles() if self in titles: return self.type + ": " + titles[self] parsed = self.parse() if parsed["type"] == "title": ident = parsed["facets"][0][0].partition(": ")[2] item = Item.from_external(ident) if item in titles: return parsed["type"] + " link for " + item.title(titles=titles) if parsed["type"]: return parsed["type"] + " link: " + self.hashid else: return "link: " + self.hashid def item_and_title(self, home=None): link = self.parse() if link["type"] != "title": return try: facet1, facet2 = link["facets"] except ValueError: return link_type, _, 
ident = facet1[0].partition(": ") item = Item.from_external(ident, home) try: ident2, start, length = parse_span(facet2[0]) except TypeError: return source_of_title = SourceDoc.from_external(ident2, home) if source_of_title: return (item, source_of_title.text[start : length + start]) @classmethod def get_all_titles(cls, home: str | None = None) -> dict["Item", str]: titles = {} for link in cls.query: ret = link.item_and_title(home) if ret is None: continue item, title = ret titles[item] = title return titles def snippet(self): return self.text class SourceDoc(Item): __tablename__ = "sourcedoc" __mapper_args__ = {"polymorphic_identity": "sourcedoc"} id = Column(Integer, ForeignKey(Item.id), primary_key=True) db_price_per_character = Column(Integer) db_document_price = Column(Integer) @property def document_price(self): return self.db_document_price or self.db_price_per_character * len(self.text) @property def price_per_character(self): return self.db_price_per_character or self.db_document_price / len(self.text) def snippet(self, length=255, killwords=False, end="...", leeway=5): s = self.text assert length >= len(end), "expected length >= %s, got %s" % (len(end), length) assert leeway >= 0, "expected leeway >= 0, got %s" % leeway if len(s) <= length + leeway: return s if killwords: return s[: length - len(end)] + end result = s[: length - len(end)].rsplit(" ", 1)[0] return result + end def raw_title(self): return self.title(with_type=False) def title(self, titles=None, with_type=True): start = self.type + ": " if with_type else "" titles = XanaLink.get_all_titles() from_link = self.title_from_link(titles=titles) if from_link: return start + from_link first_line = first_non_empty_line(self.text) if first_line: return start + first_line return start + self.hashid @property def create_xanapage_url(self): return self.view_url("create_xanapage_from_sourcedoc") @property def entire_span(self): return self.external_url + f",start=0,length={len(self.text)}" 
# Configure all mappers now that every model class in this module is defined.
# NOTE(review): presumably required by sqlalchemy_continuum (enabled via
# make_versioned at the top of the module), which builds its version classes
# during mapper configuration — confirm before moving or removing this call.
configure_mappers()