From 6ecfff56c5ecd63ce089a1e394a412501c93fd1e Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Thu, 19 Dec 2019 13:02:45 +1000 Subject: [PATCH] add pytype inferred types to anki/ I've corrected some obvious issues, and we can fix others over time. Mypy tests are currently broken, as adding the type hints has increased mypy's testing surface. --- anki/cards.py | 39 ++--- anki/collection.py | 137 ++++++++--------- anki/consts.py | 9 +- anki/db.py | 39 ++--- anki/decks.py | 111 +++++++------- anki/errors.py | 7 +- anki/exporting.py | 43 +++--- anki/find.py | 53 +++---- anki/hooks.py | 10 +- anki/importing/anki2.py | 35 ++--- anki/importing/apkg.py | 5 +- anki/importing/base.py | 9 +- anki/importing/csvfile.py | 11 +- anki/importing/mnemo.py | 10 +- anki/importing/noteimp.py | 29 ++-- anki/importing/pauker.py | 5 +- anki/importing/supermemo_xml.py | 41 +++--- anki/lang.py | 17 ++- anki/latex.py | 13 +- anki/media.py | 81 +++++----- anki/models.py | 99 ++++++------- anki/mpv.py | 67 ++++----- anki/notes.py | 43 +++--- anki/sched.py | 207 +++++++++++++------------- anki/schedv2.py | 253 ++++++++++++++++---------------- anki/sound.py | 66 +++++---- anki/stats.py | 95 ++++++------ anki/stdmodels.py | 15 +- anki/storage.py | 21 +-- anki/sync.py | 131 +++++++++-------- anki/tags.py | 37 ++--- anki/template/__init__.py | 3 +- anki/template/furigana.py | 13 +- anki/template/hint.py | 4 +- anki/template/template.py | 27 ++-- anki/template/view.py | 21 +-- anki/utils.py | 77 +++++----- aqt/main.py | 3 +- 38 files changed, 967 insertions(+), 919 deletions(-) diff --git a/anki/cards.py b/anki/cards.py index b8c78e176..3f5978c04 100644 --- a/anki/cards.py +++ b/anki/cards.py @@ -7,6 +7,7 @@ import time from anki.hooks import runHook from anki.utils import intTime, timestampID, joinFields from anki.consts import * +from typing import Any, Optional # Cards ########################################################################## @@ -21,7 +22,7 @@ from anki.consts import * class Card: - def __init__(self, col, id=None): + def __init__(self, col, id=None) -> None: self.col = col self.timerStarted = None self._qa = None @@ -46,7 +47,7 @@ class Card: self.flags = 0 self.data = "" - def load(self): + def load(self) -> None: (self.id, self.nid, self.did, @@ -69,7 +70,7 @@ class Card: self._qa = None self._note = None - def flush(self): + def flush(self) -> None: self.mod = intTime() self.usn = self.col.usn() # bug check @@ -100,7 +101,7 @@ insert or replace into cards values self.data) self.col.log(self) - def flushSched(self): + def flushSched(self) -> None: self.mod = intTime() self.usn = self.col.usn() # bug checks @@ -116,16 +117,16 @@ lapses=?, left=?, odue=?, odid=?, did=? where id = ?""", self.left, self.odue, self.odid, self.did, self.id) self.col.log(self) - def q(self, reload=False, browser=False): + def q(self, reload=False, browser=False) -> str: return self.css() + self._getQA(reload, browser)['q'] - def a(self): + def a(self) -> str: return self.css() + self._getQA()['a'] - def css(self): + def css(self) -> str: return "" % self.model()['css'] - def _getQA(self, reload=False, browser=False): + def _getQA(self, reload=False, browser=False) -> Any: if not self._qa or reload: f = self.note(reload); m = self.model(); t = self.template() data = [self.id, f.id, m['id'], self.odid or self.did, self.ord, @@ -137,45 +138,45 @@ lapses=?, left=?, odue=?, odid=?, did=? where id = ?""", self._qa = self.col._renderQA(data, *args) return self._qa - def note(self, reload=False): + def note(self, reload=False) -> Any: if not self._note or reload: self._note = self.col.getNote(self.nid) return self._note - def model(self): + def model(self) -> Any: return self.col.models.get(self.note().mid) - def template(self): + def template(self) -> Any: m = self.model() if m['type'] == MODEL_STD: return self.model()['tmpls'][self.ord] else: return self.model()['tmpls'][0] - def startTimer(self): + def startTimer(self) -> None: self.timerStarted = time.time() - def timeLimit(self): + def timeLimit(self) -> Any: "Time limit for answering in milliseconds." conf = self.col.decks.confForDid(self.odid or self.did) return conf['maxTaken']*1000 - def shouldShowTimer(self): + def shouldShowTimer(self) -> Any: conf = self.col.decks.confForDid(self.odid or self.did) return conf['timer'] - def timeTaken(self): + def timeTaken(self) -> Any: "Time taken to answer card, in integer MS." total = int((time.time() - self.timerStarted)*1000) return min(total, self.timeLimit()) - def isEmpty(self): + def isEmpty(self) -> Optional[bool]: ords = self.col.models.availOrds( self.model(), joinFields(self.note().fields)) if self.ord not in ords: return True - def __repr__(self): + def __repr__(self) -> str: d = dict(self.__dict__) # remove non-useful elements del d['_note'] @@ -184,9 +185,9 @@ lapses=?, left=?, odue=?, odid=?, did=? where id = ?""", del d['timerStarted'] return pprint.pformat(d, width=300) - def userFlag(self): + def userFlag(self) -> Any: return self.flags & 0b111 - def setUserFlag(self, flag): + def setUserFlag(self, flag) -> None: assert 0 <= flag <= 7 self.flags = (self.flags & ~0b111) | flag diff --git a/anki/collection.py b/anki/collection.py index 229e3e034..d1d2f2961 100644 --- a/anki/collection.py +++ b/anki/collection.py @@ -29,6 +29,7 @@ import anki.cards import anki.notes import anki.template import anki.find +from typing import Any, List, Optional, Tuple, Union, Dict defaultConf = { # review options @@ -49,7 +50,7 @@ defaultConf = { 'schedVer': 2, } -def timezoneOffset(): +def timezoneOffset() -> int: if time.localtime().tm_isdst: return time.altzone//60 else: @@ -58,7 +59,7 @@ def timezoneOffset(): # this is initialized by storage.Collection class _Collection: - def __init__(self, db, server=False, log=False): + def __init__(self, db, server=False, log=False) -> None: self._debugLog = log self.db = db self.path = db._path @@ -85,7 +86,7 @@ class _Collection: self.conf['newBury'] = True self.setMod() - def name(self): + def name(self) -> Any: n = os.path.splitext(os.path.basename(self.path))[0] return n @@ -94,14 +95,14 @@ class _Collection: supportedSchedulerVersions = (1, 2) - def schedVer(self): + def schedVer(self) -> Any: ver = self.conf.get("schedVer", 1) if ver in self.supportedSchedulerVersions: return ver else: raise Exception("Unsupported scheduler version") - def _loadScheduler(self): + def _loadScheduler(self) -> None: ver = self.schedVer() if ver == 1: from anki.sched import Scheduler @@ -110,7 +111,7 @@ class _Collection: self.sched = Scheduler(self) - def changeSchedulerVer(self, ver): + def changeSchedulerVer(self, ver) -> None: if ver == self.schedVer(): return if ver not in self.supportedSchedulerVersions: @@ -135,7 +136,7 @@ class _Collection: # DB-related ########################################################################## - def load(self): + def load(self) -> None: (self.crt, self.mod, self.scm, @@ -154,14 +155,14 @@ conf, models, decks, dconf, tags from col""") self.decks.load(decks, dconf) self.tags.load(tags) - def setMod(self): + def setMod(self) -> None: """Mark DB modified. DB operations and the deck/tag/model managers do this automatically, so this is only necessary if you modify properties of this object or the conf dict.""" self.db.mod = True - def flush(self, mod=None): + def flush(self, mod=None) -> None: "Flush state to DB, updating mod time." self.mod = intTime(1000) if mod is None else mod self.db.execute( @@ -170,7 +171,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self.crt, self.mod, self.scm, self.dty, self._usn, self.ls, json.dumps(self.conf)) - def save(self, name=None, mod=None): + def save(self, name=None, mod=None) -> None: "Flush, commit DB, and take out another write lock." # let the managers conditionally flush self.models.flush() @@ -185,19 +186,19 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self._markOp(name) self._lastSave = time.time() - def autosave(self): + def autosave(self) -> Optional[bool]: "Save if 5 minutes has passed since last save. True if saved." if time.time() - self._lastSave > 300: self.save() return True - def lock(self): + def lock(self) -> None: # make sure we don't accidentally bump mod time mod = self.db.mod self.db.execute("update col set mod=mod") self.db.mod = mod - def close(self, save=True): + def close(self, save=True) -> None: "Disconnect from DB." if self.db: if save: @@ -213,7 +214,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self.media.close() self._closeLog() - def reopen(self): + def reopen(self) -> None: "Reconnect to DB (after changing threads, etc)." import anki.db if not self.db: @@ -221,12 +222,12 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self.media.connect() self._openLog() - def rollback(self): + def rollback(self) -> None: self.db.rollback() self.load() self.lock() - def modSchema(self, check): + def modSchema(self, check) -> None: "Mark schema modified. Call this first so user can abort if necessary." if not self.schemaChanged(): if check and not runFilter("modSchema", True): @@ -234,14 +235,14 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self.scm = intTime(1000) self.setMod() - def schemaChanged(self): + def schemaChanged(self) -> Any: "True if schema changed since last sync." return self.scm > self.ls - def usn(self): + def usn(self) -> Any: return self._usn if self.server else -1 - def beforeUpload(self): + def beforeUpload(self) -> None: "Called before a full upload." tbls = "notes", "cards", "revlog" for t in tbls: @@ -263,44 +264,44 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", # Object creation helpers ########################################################################## - def getCard(self, id): + def getCard(self, id) -> anki.cards.Card: return anki.cards.Card(self, id) - def getNote(self, id): + def getNote(self, id) -> anki.notes.Note: return anki.notes.Note(self, id=id) # Utils ########################################################################## - def nextID(self, type, inc=True): + def nextID(self, type, inc=True) -> Any: type = "next"+type.capitalize() id = self.conf.get(type, 1) if inc: self.conf[type] = id+1 return id - def reset(self): + def reset(self) -> None: "Rebuild the queue and reload data after DB modified." self.sched.reset() # Deletion logging ########################################################################## - def _logRem(self, ids, type): + def _logRem(self, ids, type) -> None: self.db.executemany("insert into graves values (%d, ?, %d)" % ( self.usn(), type), ([x] for x in ids)) # Notes ########################################################################## - def noteCount(self): + def noteCount(self) -> Any: return self.db.scalar("select count() from notes") - def newNote(self, forDeck=True): + def newNote(self, forDeck=True) -> anki.notes.Note: "Return a new note with the current model." return anki.notes.Note(self, self.models.current(forDeck)) - def addNote(self, note): + def addNote(self, note) -> int: "Add a note to the collection. Return number of new cards." # check we have card models available, then save cms = self.findTemplates(note) @@ -316,11 +317,11 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", ncards += 1 return ncards - def remNotes(self, ids): + def remNotes(self, ids) -> None: self.remCards(self.db.list("select id from cards where nid in "+ ids2str(ids))) - def _remNotes(self, ids): + def _remNotes(self, ids) -> None: "Bulk delete notes by ID. Don't call this directly." if not ids: return @@ -334,13 +335,13 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", # Card creation ########################################################################## - def findTemplates(self, note): + def findTemplates(self, note) -> List: "Return (active), non-empty templates." model = note.model() avail = self.models.availOrds(model, joinFields(note.fields)) return self._tmplsFromOrds(model, avail) - def _tmplsFromOrds(self, model, avail): + def _tmplsFromOrds(self, model, avail) -> List: ok = [] if model['type'] == MODEL_STD: for t in model['tmpls']: @@ -354,7 +355,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", ok.append(t) return ok - def genCards(self, nids): + def genCards(self, nids) -> List: "Generate cards for non-empty templates, return ids to remove." # build map of (nid,ord) so we don't create dupes snids = ids2str(nids) @@ -428,7 +429,7 @@ insert into cards values (?,?,?,?,?,?,0,0,?,0,0,0,0,0,0,0,0,"")""", # type 0 - when previewing in add dialog, only non-empty # type 1 - when previewing edit, only existing # type 2 - when previewing in models dialog, all templates - def previewCards(self, note, type=0, did=None): + def previewCards(self, note, type=0, did=None) -> List: if type == 0: cms = self.findTemplates(note) elif type == 1: @@ -442,7 +443,7 @@ insert into cards values (?,?,?,?,?,?,0,0,?,0,0,0,0,0,0,0,0,"")""", cards.append(self._newCard(note, template, 1, flush=False, did=did)) return cards - def _newCard(self, note, template, due, flush=True, did=None): + def _newCard(self, note, template, due, flush=True, did=None) -> anki.cards.Card: "Create a new card." card = anki.cards.Card(self) card.nid = note.id @@ -469,7 +470,7 @@ insert into cards values (?,?,?,?,?,?,0,0,?,0,0,0,0,0,0,0,0,"")""", card.flush() return card - def _dueForDid(self, did, due): + def _dueForDid(self, did, due: int) -> int: conf = self.decks.confForDid(did) # in order due? if conf['new']['order'] == NEW_CARDS_DUE: @@ -484,13 +485,13 @@ insert into cards values (?,?,?,?,?,?,0,0,?,0,0,0,0,0,0,0,0,"")""", # Cards ########################################################################## - def isEmpty(self): + def isEmpty(self) -> bool: return not self.db.scalar("select 1 from cards limit 1") - def cardCount(self): + def cardCount(self) -> Any: return self.db.scalar("select count() from cards") - def remCards(self, ids, notes=True): + def remCards(self, ids, notes=True) -> None: "Bulk delete cards by ID." if not ids: return @@ -507,13 +508,13 @@ select id from notes where id in %s and id not in (select nid from cards)""" % ids2str(nids)) self._remNotes(nids) - def emptyCids(self): + def emptyCids(self) -> List: rem = [] for m in self.models.all(): rem += self.genCards(self.models.nids(m)) return rem - def emptyCardReport(self, cids): + def emptyCardReport(self, cids) -> str: rep = "" for ords, cnt, flds in self.db.all(""" select group_concat(ord+1), count(), flds from cards c, notes n @@ -525,11 +526,11 @@ where c.nid = n.id and c.id in %s group by nid""" % ids2str(cids)): # Field checksums and sorting fields ########################################################################## - def _fieldData(self, snids): + def _fieldData(self, snids) -> Any: return self.db.execute( "select id, mid, flds from notes where id in "+snids) - def updateFieldCache(self, nids): + def updateFieldCache(self, nids) -> None: "Update field checksums and sort cache, after find&replace, etc." snids = ids2str(nids) r = [] @@ -548,7 +549,7 @@ where c.nid = n.id and c.id in %s group by nid""" % ids2str(cids)): # Q/A generation ########################################################################## - def renderQA(self, ids=None, type="card"): + def renderQA(self, ids=None, type="card") -> List: # gather metadata if type == "card": where = "and c.id in " + ids2str(ids) @@ -563,7 +564,7 @@ where c.nid = n.id and c.id in %s group by nid""" % ids2str(cids)): return [self._renderQA(row) for row in self._qaData(where)] - def _renderQA(self, data, qfmt=None, afmt=None): + def _renderQA(self, data, qfmt=None, afmt=None) -> Dict: "Returns hash of id, question, answer." # data is [cid, nid, mid, did, ord, tags, flds, cardFlags] # unpack fields and create dict @@ -610,7 +611,7 @@ where c.nid = n.id and c.id in %s group by nid""" % ids2str(cids)): "%s" % (HELP_SITE, _("help")))) return d - def _qaData(self, where=""): + def _qaData(self, where="") -> Any: "Return [cid, nid, mid, did, ord, tags, flds, cardFlags] db query" return self.db.execute(""" select c.id, f.id, f.mid, c.did, c.ord, f.tags, f.flds, c.flags @@ -618,7 +619,7 @@ from cards c, notes f where c.nid == f.id %s""" % where) - def _flagNameFromCardFlags(self, flags): + def _flagNameFromCardFlags(self, flags) -> str: flag = flags & 0b111 if not flag: return "" @@ -627,37 +628,37 @@ where c.nid == f.id # Finding cards ########################################################################## - def findCards(self, query, order=False): + def findCards(self, query, order=False) -> Any: return anki.find.Finder(self).findCards(query, order) - def findNotes(self, query): + def findNotes(self, query) -> Any: return anki.find.Finder(self).findNotes(query) - def findReplace(self, nids, src, dst, regex=None, field=None, fold=True): + def findReplace(self, nids, src, dst, regex=None, field=None, fold=True) -> int: return anki.find.findReplace(self, nids, src, dst, regex, field, fold) - def findDupes(self, fieldName, search=""): + def findDupes(self, fieldName, search="") -> List[Tuple[Any, list]]: return anki.find.findDupes(self, fieldName, search) # Stats ########################################################################## - def cardStats(self, card): + def cardStats(self, card) -> str: from anki.stats import CardStats return CardStats(self, card).report() - def stats(self): + def stats(self) -> "anki.stats.CollectionStats": from anki.stats import CollectionStats return CollectionStats(self) # Timeboxing ########################################################################## - def startTimebox(self): + def startTimebox(self) -> None: self._startTime = time.time() self._startReps = self.sched.reps - def timeboxReached(self): + def timeboxReached(self) -> Optional[Union[bool, Tuple[Any, int]]]: "Return (elapsedTime, reps) if timebox reached, or False." if not self.conf['timeLim']: # timeboxing disabled @@ -669,24 +670,24 @@ where c.nid == f.id # Undo ########################################################################## - def clearUndo(self): + def clearUndo(self) -> None: # [type, undoName, data] # type 1 = review; type 2 = checkpoint self._undo = None - def undoName(self): + def undoName(self) -> Any: "Undo menu item name, or None if undo unavailable." if not self._undo: return None return self._undo[1] - def undo(self): + def undo(self) -> Any: if self._undo[0] == 1: return self._undoReview() else: self._undoOp() - def markReview(self, card): + def markReview(self, card) -> None: old = [] if self._undo: if self._undo[0] == 1: @@ -695,7 +696,7 @@ where c.nid == f.id wasLeech = card.note().hasTag("leech") or False self._undo = [1, _("Review"), old + [copy.copy(card)], wasLeech] - def _undoReview(self): + def _undoReview(self) -> Any: data = self._undo[2] wasLeech = self._undo[3] c = data.pop() # pytype: disable=attribute-error @@ -723,7 +724,7 @@ where c.nid == f.id self.sched.reps -= 1 return c.id - def _markOp(self, name): + def _markOp(self, name) -> None: "Call via .save()" if name: self._undo = [2, name] @@ -732,14 +733,14 @@ where c.nid == f.id if self._undo and self._undo[0] == 2: self.clearUndo() - def _undoOp(self): + def _undoOp(self) -> None: self.rollback() self.clearUndo() # DB maintenance ########################################################################## - def basicCheck(self): + def basicCheck(self) -> Optional[bool]: "Basic integrity check for syncing. True if ok." # cards without notes if self.db.scalar(""" @@ -763,7 +764,7 @@ select id from notes where mid = ?) limit 1""" % return return True - def fixIntegrity(self): + def fixIntegrity(self) -> Tuple[Any, bool]: "Fix possible problems and rebuild caches." problems = [] curs = self.db.cursor() @@ -901,7 +902,7 @@ and type=0""", [intTime(), self.usn()]) self.save() return ("\n".join(problems), ok) - def optimize(self): + def optimize(self) -> None: self.db.setAutocommit(True) self.db.execute("vacuum") self.db.execute("analyze") @@ -911,7 +912,7 @@ and type=0""", [intTime(), self.usn()]) # Logging ########################################################################## - def log(self, *args, **kwargs): + def log(self, *args, **kwargs) -> None: if not self._debugLog: return def customRepr(x): @@ -926,7 +927,7 @@ and type=0""", [intTime(), self.usn()]) if devMode: print(buf) - def _openLog(self): + def _openLog(self) -> None: if not self._debugLog: return lpath = re.sub(r"\.anki2$", ".log", self.path) @@ -937,7 +938,7 @@ and type=0""", [intTime(), self.usn()]) os.rename(lpath, lpath2) self._logHnd = open(lpath, "a", encoding="utf8") - def _closeLog(self): + def _closeLog(self) -> None: if not self._debugLog: return self._logHnd.close() @@ -946,7 +947,7 @@ and type=0""", [intTime(), self.usn()]) # Card Flags ########################################################################## - def setUserFlag(self, flag, cids): + def setUserFlag(self, flag, cids) -> None: assert 0 <= flag <= 7 self.db.execute("update cards set flags = (flags & ~?) | ?, usn=?, mod=? where id in %s" % ids2str(cids), 0b111, flag, self.usn(), intTime()) diff --git a/anki/consts.py b/anki/consts.py index dde15fe93..f29683396 100644 --- a/anki/consts.py +++ b/anki/consts.py @@ -3,6 +3,7 @@ # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html from anki.lang import _ +from typing import Any, Dict # whether new cards should be mixed with reviews, or shown first or last NEW_CARDS_DISTRIBUTE = 0 @@ -57,27 +58,27 @@ HELP_SITE="http://ankisrs.net/docs/manual.html" # Labels ########################################################################## -def newCardOrderLabels(): +def newCardOrderLabels() -> Dict[int, Any]: return { 0: _("Show new cards in random order"), 1: _("Show new cards in order added") } -def newCardSchedulingLabels(): +def newCardSchedulingLabels() -> Dict[int, Any]: return { 0: _("Mix new cards and reviews"), 1: _("Show new cards after reviews"), 2: _("Show new cards before reviews"), } -def alignmentLabels(): +def alignmentLabels() -> Dict[int, Any]: return { 0: _("Center"), 1: _("Left"), 2: _("Right"), } -def dynOrderLabels(): +def dynOrderLabels() -> Dict[int, Any]: return { 0: _("Oldest seen first"), 1: _("Random"), diff --git a/anki/db.py b/anki/db.py index 7209a595b..a1c009b57 100644 --- a/anki/db.py +++ b/anki/db.py @@ -6,18 +6,19 @@ import os import time from sqlite3 import dbapi2 as sqlite, Cursor +from typing import Any, List DBError = sqlite.Error class DB: - def __init__(self, path, timeout=0): + def __init__(self, path, timeout=0) -> None: self._db = sqlite.connect(path, timeout=timeout) self._db.text_factory = self._textFactory self._path = path self.echo = os.environ.get("DBECHO") self.mod = False - def execute(self, sql, *a, **ka): + def execute(self, sql, *a, **ka) -> Cursor: s = sql.strip().lower() # mark modified? for stmt in "insert", "update", "delete": @@ -37,7 +38,7 @@ class DB: print(a, ka) return res - def executemany(self, sql, l): + def executemany(self, sql, l) -> None: self.mod = True t = time.time() self._db.executemany(sql, l) @@ -46,68 +47,68 @@ class DB: if self.echo == "2": print(l) - def commit(self): + def commit(self) -> None: t = time.time() self._db.commit() if self.echo: print("commit %0.3fms" % ((time.time() - t)*1000)) - def executescript(self, sql): + def executescript(self, sql) -> None: self.mod = True if self.echo: print(sql) self._db.executescript(sql) - def rollback(self): + def rollback(self) -> None: self._db.rollback() - def scalar(self, *a, **kw): + def scalar(self, *a, **kw) -> Any: res = self.execute(*a, **kw).fetchone() if res: return res[0] return None - def all(self, *a, **kw): + def all(self, *a, **kw) -> List: return self.execute(*a, **kw).fetchall() - def first(self, *a, **kw): + def first(self, *a, **kw) -> Any: c = self.execute(*a, **kw) res = c.fetchone() c.close() return res - def list(self, *a, **kw): + def list(self, *a, **kw) -> List: return [x[0] for x in self.execute(*a, **kw)] - def close(self): + def close(self) -> None: self._db.text_factory = None self._db.close() - def set_progress_handler(self, *args): + def set_progress_handler(self, *args) -> None: self._db.set_progress_handler(*args) - def __enter__(self): + def __enter__(self) -> "DB": self._db.execute("begin") return self - def __exit__(self, exc_type, *args): + def __exit__(self, exc_type, *args) -> None: self._db.close() - def totalChanges(self): + def totalChanges(self) -> Any: return self._db.total_changes - def interrupt(self): + def interrupt(self) -> None: self._db.interrupt() - def setAutocommit(self, autocommit): + def setAutocommit(self, autocommit) -> None: if autocommit: self._db.isolation_level = None else: self._db.isolation_level = '' # strip out invalid utf-8 when reading from db - def _textFactory(self, data): + def _textFactory(self, data) -> str: return str(data, errors="ignore") - def cursor(self, factory=Cursor): + def cursor(self, factory=Cursor) -> Cursor: return self._db.cursor(factory) diff --git a/anki/decks.py b/anki/decks.py index 3adf8a373..9126526d7 100644 --- a/anki/decks.py +++ b/anki/decks.py @@ -11,6 +11,7 @@ from anki.hooks import runHook from anki.consts import * from anki.lang import _ from anki.errors import DeckRenameError +from typing import Any, Dict, List, Optional, Tuple # fixmes: # - make sure users can't set grad interval < 1 @@ -94,12 +95,12 @@ class DeckManager: # Registry save/load ############################################################# - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.decks = {} self.dconf = {} - def load(self, decks, dconf): + def load(self, decks, dconf) -> None: self.decks = json.loads(decks) self.dconf = json.loads(dconf) # set limits to within bounds @@ -114,14 +115,14 @@ class DeckManager: if not found: self.changed = False - def save(self, g=None): + def save(self, g=None) -> None: "Can be called with either a deck or a deck configuration." if g: g['mod'] = intTime() g['usn'] = self.col.usn() self.changed = True - def flush(self): + def flush(self) -> None: if self.changed: self.col.db.execute("update col set decks=?, dconf=?", json.dumps(self.decks), @@ -131,7 +132,7 @@ class DeckManager: # Deck save/load ############################################################# - def id(self, name, create=True, type=None): + def id(self, name, create=True, type=None) -> Optional[int]: "Add a deck with NAME. Reuse deck if already exists. Return id as int." if type is None: type = defaultDeck @@ -158,7 +159,7 @@ class DeckManager: runHook("newDeck") return int(id) - def rem(self, did, cardsToo=False, childrenToo=True): + def rem(self, did, cardsToo=False, childrenToo=True) -> None: "Remove the deck. If cardsToo, delete any cards inside." if str(did) == '1': # we won't allow the default deck to be deleted, but if it's a @@ -208,58 +209,58 @@ class DeckManager: self.select(int(list(self.decks.keys())[0])) self.save() - def allNames(self, dyn=True, forceDefault=True): + def allNames(self, dyn=True, forceDefault=True) -> List: "An unsorted list of all deck names." if dyn: return [x['name'] for x in self.all(forceDefault=forceDefault)] else: return [x['name'] for x in self.all(forceDefault=forceDefault) if not x['dyn']] - def all(self, forceDefault=True): + def all(self, forceDefault=True) -> List: "A list of all decks." decks = list(self.decks.values()) if not forceDefault and not self.col.db.scalar("select 1 from cards where did = 1 limit 1") and len(decks)>1: decks = [deck for deck in decks if deck['id'] != 1] return decks - def allIds(self): + def allIds(self) -> List[str]: return list(self.decks.keys()) - def collapse(self, did): + def collapse(self, did) -> None: deck = self.get(did) deck['collapsed'] = not deck['collapsed'] self.save(deck) - def collapseBrowser(self, did): + def collapseBrowser(self, did) -> None: deck = self.get(did) collapsed = deck.get('browserCollapsed', False) deck['browserCollapsed'] = not collapsed self.save(deck) - def count(self): + def count(self) -> int: return len(self.decks) - def get(self, did, default=True): + def get(self, did, default=True) -> Any: id = str(did) if id in self.decks: return self.decks[id] elif default: return self.decks['1'] - def byName(self, name): + def byName(self, name) -> Any: """Get deck with NAME, ignoring case.""" for m in list(self.decks.values()): if self.equalName(m['name'], name): return m - def update(self, g): + def update(self, g) -> None: "Add or update an existing deck. Used for syncing and merging." self.decks[str(g['id'])] = g self.maybeAddToActive() # mark registry changed, but don't bump mod time self.save() - def rename(self, g, newName): + def rename(self, g, newName) -> None: "Rename deck prefix to NAME if not exists. Updates children." # make sure target node doesn't already exist if self.byName(newName): @@ -284,7 +285,7 @@ class DeckManager: # renaming may have altered active did order self.maybeAddToActive() - def renameForDragAndDrop(self, draggedDeckDid, ontoDeckDid): + def renameForDragAndDrop(self, draggedDeckDid, ontoDeckDid) -> None: draggedDeck = self.get(draggedDeckDid) draggedDeckName = draggedDeck['name'] ontoDeckName = self.get(ontoDeckDid)['name'] @@ -299,7 +300,7 @@ class DeckManager: assert ontoDeckName.strip() self.rename(draggedDeck, ontoDeckName + "::" + self._basename(draggedDeckName)) - def _canDragAndDrop(self, draggedDeckName, ontoDeckName): + def _canDragAndDrop(self, draggedDeckName, ontoDeckName) -> bool: if draggedDeckName == ontoDeckName \ or self._isParent(ontoDeckName, draggedDeckName) \ or self._isAncestor(draggedDeckName, ontoDeckName): @@ -307,19 +308,19 @@ class DeckManager: else: return True - def _isParent(self, parentDeckName, childDeckName): + def _isParent(self, parentDeckName, childDeckName) -> Any: return self._path(childDeckName) == self._path(parentDeckName) + [ self._basename(childDeckName) ] - def _isAncestor(self, ancestorDeckName, descendantDeckName): + def _isAncestor(self, ancestorDeckName, descendantDeckName) -> Any: ancestorPath = self._path(ancestorDeckName) return ancestorPath == self._path(descendantDeckName)[0:len(ancestorPath)] - def _path(self, name): + def _path(self, name) -> Any: return name.split("::") - def _basename(self, name): + def _basename(self, name) -> Any: return self._path(name)[-1] - def _ensureParents(self, name): + def _ensureParents(self, name) -> Any: "Ensure parents exist, and return name with case matching parents." s = "" path = self._path(name) @@ -340,11 +341,11 @@ class DeckManager: # Deck configurations ############################################################# - def allConf(self): + def allConf(self) -> List: "A list of all deck config." return list(self.dconf.values()) - def confForDid(self, did): + def confForDid(self, did) -> Any: deck = self.get(did, default=False) assert deck if 'conf' in deck: @@ -354,14 +355,14 @@ class DeckManager: # dynamic decks have embedded conf return deck - def getConf(self, confId): + def getConf(self, confId) -> Any: return self.dconf[str(confId)] - def updateConf(self, g): + def updateConf(self, g) -> None: self.dconf[str(g['id'])] = g self.save() - def confId(self, name, cloneFrom=None): + def confId(self, name, cloneFrom=None) -> int: "Create a new configuration and return id." if cloneFrom is None: cloneFrom = defaultConf @@ -376,7 +377,7 @@ class DeckManager: self.save(c) return id - def remConf(self, id): + def remConf(self, id) -> None: "Remove a configuration and update all decks using it." assert int(id) != 1 self.col.modSchema(check=True) @@ -389,18 +390,18 @@ class DeckManager: g['conf'] = 1 self.save(g) - def setConf(self, grp, id): + def setConf(self, grp, id) -> None: grp['conf'] = id self.save(grp) - def didsForConf(self, conf): + def didsForConf(self, conf) -> List: dids = [] for deck in list(self.decks.values()): if 'conf' in deck and deck['conf'] == conf['id']: dids.append(deck['id']) return dids - def restoreToDefault(self, conf): + def restoreToDefault(self, conf) -> None: oldOrder = conf['new']['order'] new = copy.deepcopy(defaultConf) new['id'] = conf['id'] @@ -414,29 +415,29 @@ class DeckManager: # Deck utils ############################################################# - def name(self, did, default=False): + def name(self, did, default=False) -> Any: deck = self.get(did, default=default) if deck: return deck['name'] return _("[no deck]") - def nameOrNone(self, did): + def nameOrNone(self, did) -> Any: deck = self.get(did, default=False) if deck: return deck['name'] return None - def setDeck(self, cids, did): + def setDeck(self, cids, did) -> None: self.col.db.execute( "update cards set did=?,usn=?,mod=? where id in "+ ids2str(cids), did, self.col.usn(), intTime()) - def maybeAddToActive(self): + def maybeAddToActive(self) -> None: # reselect current deck, or default if current has disappeared c = self.current() self.select(c['id']) - def cids(self, did, children=False): + def cids(self, did, children=False) -> Any: if not children: return self.col.db.list("select id from cards where did=?", did) dids = [did] @@ -445,14 +446,14 @@ class DeckManager: return self.col.db.list("select id from cards where did in "+ ids2str(dids)) - def _recoverOrphans(self): + def _recoverOrphans(self) -> None: dids = list(self.decks.keys()) mod = self.col.db.mod self.col.db.execute("update cards set did = 1 where did not in "+ ids2str(dids)) self.col.db.mod = mod - def _checkDeckTree(self): + def _checkDeckTree(self) -> None: decks = self.col.decks.all() decks.sort(key=operator.itemgetter('name')) names = set() @@ -480,25 +481,25 @@ class DeckManager: names.add(deck['name']) - def checkIntegrity(self): + def checkIntegrity(self) -> None: self._recoverOrphans() self._checkDeckTree() # Deck selection ############################################################# - def active(self): + def active(self) -> Any: "The currrently active dids. Make sure to copy before modifying." return self.col.conf['activeDecks'] - def selected(self): + def selected(self) -> Any: "The currently selected did." return self.col.conf['curDeck'] - def current(self): + def current(self) -> Any: return self.get(self.selected()) - def select(self, did): + def select(self, did) -> None: "Select a new branch." # make sure arg is an int did = int(did) @@ -510,7 +511,7 @@ class DeckManager: self.col.conf['activeDecks'] = [did] + [a[1] for a in actv] self.changed = True - def children(self, did): + def children(self, did) -> List[Tuple[Any, Any]]: "All children of did, as (name, id)." name = self.get(did)['name'] actv = [] @@ -519,7 +520,7 @@ class DeckManager: actv.append((g['name'], g['id'])) return actv - def childDids(self, did, childMap): + def childDids(self, did, childMap) -> List: def gather(node, arr): for did, child in node.items(): arr.append(did) @@ -529,7 +530,7 @@ class DeckManager: gather(childMap[did], arr) return arr - def childMap(self): + def childMap(self) -> Dict[Any, Dict[Any, dict]]: nameMap = self.nameMap() childMap = {} @@ -547,7 +548,7 @@ class DeckManager: return childMap - def parents(self, did, nameMap=None): + def parents(self, did, nameMap=None) -> List: "All parents of did." # get parent and grandparent names parents = [] @@ -565,7 +566,7 @@ class DeckManager: parents[c] = deck return parents - def parentsByName(self, name): + def parentsByName(self, name) -> List: "All existing parents of name" if "::" not in name: return [] @@ -581,13 +582,13 @@ class DeckManager: return parents - def nameMap(self): + def nameMap(self) -> dict: return dict((d['name'], d) for d in self.decks.values()) # Sync handling ########################################################################## - def beforeUpload(self): + def beforeUpload(self) -> None: for d in self.all(): d['usn'] = 0 for c in self.allConf(): @@ -597,19 +598,19 @@ class DeckManager: # Dynamic decks ########################################################################## - def newDyn(self, name): + def newDyn(self, name) -> int: "Return a new dynamic deck and set it as the current deck." did = self.id(name, type=defaultDynamicDeck) self.select(did) return did - def isDyn(self, did): + def isDyn(self, did) -> Any: return self.get(did)['dyn'] @staticmethod - def normalizeName(name): + def normalizeName(name) -> str: return unicodedata.normalize("NFC", name.lower()) @staticmethod - def equalName(name1, name2): + def equalName(name1, name2) -> bool: return DeckManager.normalizeName(name1) == DeckManager.normalizeName(name2) diff --git a/anki/errors.py b/anki/errors.py index 604e83ae2..d283b159a 100644 --- a/anki/errors.py +++ b/anki/errors.py @@ -1,20 +1,21 @@ +from typing import Any # -*- coding: utf-8 -*- # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html class AnkiError(Exception): - def __init__(self, type, **data): + def __init__(self, type, **data) -> None: super().__init__() self.type = type self.data = data - def __str__(self): + def __str__(self) -> Any: m = self.type if self.data: m += ": %s" % repr(self.data) return m class DeckRenameError(Exception): - def __init__(self, description): + def __init__(self, description) -> None: super().__init__() self.description = description def __str__(self): diff --git a/anki/exporting.py b/anki/exporting.py index d4499c622..b72694a90 100644 --- a/anki/exporting.py +++ b/anki/exporting.py @@ -9,24 +9,25 @@ from anki.lang import _ from anki.utils import ids2str, splitFields, namedtmp, stripHTML from anki.hooks import runHook from anki.storage import Collection +from typing import Any, Dict, List, Tuple class Exporter: includeHTML: typing.Union[bool, None] = None - def __init__(self, col, did=None): + def __init__(self, col, did=None) -> None: self.col = col self.did = did - def doExport(self, path): + def doExport(self, path) -> None: raise Exception("not implemented") - def exportInto(self, path): + def exportInto(self, path) -> None: self._escapeCount = 0 file = open(path, "wb") self.doExport(file) file.close() - def processText(self, text): + def processText(self, text) -> str: if self.includeHTML is False: text = self.stripHTML(text) @@ -34,7 +35,7 @@ class Exporter: return text - def escapeText(self, text): + def escapeText(self, text) -> str: "Escape newlines, tabs, CSS and quotechar." # fixme: we should probably quote fields with newlines # instead of converting them to spaces @@ -46,7 +47,7 @@ class Exporter: text = "\"" + text.replace("\"", "\"\"") + "\"" return text - def stripHTML(self, text): + def stripHTML(self, text) -> str: # very basic conversion to text s = text s = re.sub(r"(?i)<(br ?/?|div|p)>", " ", s) @@ -56,7 +57,7 @@ class Exporter: s = s.strip() return s - def cardIds(self): + def cardIds(self) -> Any: if not self.did: cids = self.col.db.list("select id from cards") else: @@ -73,10 +74,10 @@ class TextCardExporter(Exporter): ext = ".txt" includeHTML = True - def __init__(self, col): + def __init__(self, col) -> None: Exporter.__init__(self, col) - def doExport(self, file): + def doExport(self, file) -> None: ids = sorted(self.cardIds()) strids = ids2str(ids) def esc(s): @@ -100,11 +101,11 @@ class TextNoteExporter(Exporter): includeTags = True includeHTML = True - def __init__(self, col): + def __init__(self, col) -> None: Exporter.__init__(self, col) self.includeID = False - def doExport(self, file): + def doExport(self, file) -> None: cardIds = self.cardIds() data = [] for id, flds, tags in self.col.db.execute(""" @@ -137,10 +138,10 @@ class AnkiExporter(Exporter): includeSched: typing.Union[bool, None] = False includeMedia = True - def __init__(self, col): + def __init__(self, col) -> None: Exporter.__init__(self, col) - def exportInto(self, path): + def exportInto(self, path) -> None: # sched info+v2 scheduler not compatible w/ older clients self._v2sched = self.col.schedVer() != 1 and self.includeSched @@ -253,15 +254,15 @@ class AnkiExporter(Exporter): self.postExport() self.dst.close() - def postExport(self): + def postExport(self) -> None: # overwrite to apply customizations to the deck before it's closed, # such as update the deck description pass - def removeSystemTags(self, tags): + def removeSystemTags(self, tags) -> Any: return self.src.tags.remFromStr("marked leech", tags) - def _modelHasMedia(self, model, fname): + def _modelHasMedia(self, model, fname) -> bool: # First check the styling if fname in model["css"]: return True @@ -290,7 +291,7 @@ class AnkiPackageExporter(AnkiExporter): z.writestr("media", json.dumps(media)) z.close() - def doExport(self, z, path): + def doExport(self, z, path) -> Dict[str, str]: # export into the anki2 file colfile = path.replace(".apkg", ".anki2") AnkiExporter.exportInto(self, colfile) @@ -314,7 +315,7 @@ class AnkiPackageExporter(AnkiExporter): shutil.rmtree(path.replace(".apkg", ".media")) return media - def _exportMedia(self, z, files, fdir): + def _exportMedia(self, z, files, fdir) -> Dict[str, str]: media = {} for c, file in enumerate(files): cStr = str(c) @@ -331,14 +332,14 @@ class AnkiPackageExporter(AnkiExporter): return media - def prepareMedia(self): + def prepareMedia(self) -> None: # chance to move each file in self.mediaFiles into place before media # is zipped up pass # create a dummy collection to ensure older clients don't try to read # data they don't understand - def _addDummyCollection(self, zip): + def _addDummyCollection(self, zip) -> None: path = namedtmp("dummy.anki2") c = Collection(path) n = c.newNote() @@ -383,7 +384,7 @@ class AnkiCollectionPackageExporter(AnkiPackageExporter): # Export modules ########################################################################## -def exporters(): +def exporters() -> List[Tuple[str, Any]]: def id(obj): return ("%s (*%s)" % (obj.key, obj.ext), obj) exps = [ diff --git a/anki/find.py b/anki/find.py index e4d5f7a21..0ad5d42c3 100644 --- a/anki/find.py +++ b/anki/find.py @@ -9,6 +9,7 @@ import unicodedata from anki.utils import ids2str, splitFields, joinFields, intTime, fieldChecksum, stripHTMLMedia from anki.consts import * from anki.hooks import * +from typing import Any, List, Optional, Tuple # Find @@ -16,7 +17,7 @@ from anki.hooks import * class Finder: - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.search = dict( added=self._findAdded, @@ -35,7 +36,7 @@ class Finder: self.search['is'] = self._findCardState runHook("search", self.search) - def findCards(self, query, order=False): + def findCards(self, query, order=False) -> Any: "Return a list of card ids for QUERY." tokens = self._tokenize(query) preds, args = self._where(tokens) @@ -52,7 +53,7 @@ class Finder: res.reverse() return res - def findNotes(self, query): + def findNotes(self, query) -> Any: tokens = self._tokenize(query) preds, args = self._where(tokens) if preds is None: @@ -73,7 +74,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds # Tokenizing ###################################################################### - def _tokenize(self, query): + def _tokenize(self, query) -> List: inQuote = False tokens = [] token = "" @@ -127,7 +128,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds # Query building ###################################################################### - def _where(self, tokens): + def _where(self, tokens) -> Tuple[Any, Optional[List[str]]]: # state and query s = dict(isnot=False, isor=False, join=False, q="", bad=False) args = [] @@ -185,7 +186,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds return None, None return s['q'], args - def _query(self, preds, order): + def _query(self, preds, order) -> str: # can we skip the note table? if "n." not in preds and "n." not in order: sql = "select c.id from cards c where " @@ -204,7 +205,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds # Ordering ###################################################################### - def _order(self, order): + def _order(self, order) -> Tuple[Any, Any]: if not order: return "", False elif order is not True: @@ -241,7 +242,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds # Commands ###################################################################### - def _findTag(self, args): + def _findTag(self, args) -> str: (val, args) = args if val == "none": return 'n.tags = ""' @@ -253,7 +254,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds args.append(val) return "n.tags like ? escape '\\'" - def _findCardState(self, args): + def _findCardState(self, args) -> Optional[str]: (val, args) = args if val in ("review", "new", "learn"): if val == "review": @@ -273,7 +274,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds (c.queue = 1 and c.due <= %d)""" % ( self.col.sched.today, self.col.sched.dayCutoff) - def _findFlag(self, args): + def _findFlag(self, args) -> Optional[str]: (val, args) = args if not val or len(val)!=1 or val not in "01234": return @@ -281,7 +282,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds mask = 2**3 - 1 return "(c.flags & %d) == %d" % (mask, val) - def _findRated(self, args): + def _findRated(self, args) -> Optional[str]: # days(:optional_ease) (val, args) = args r = val.split(":") @@ -300,7 +301,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds return ("c.id in (select cid from revlog where id>%d %s)" % (cutoff, ease)) - def _findAdded(self, args): + def _findAdded(self, args) -> Optional[str]: (val, args) = args try: days = int(val) @@ -309,7 +310,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds cutoff = (self.col.sched.dayCutoff - 86400*days)*1000 return "c.id > %d" % cutoff - def _findProp(self, args): + def _findProp(self, args) -> Optional[str]: # extract (val, args) = args m = re.match("(^.+?)(<=|>=|!=|=|<|>)(.+?$)", val) @@ -340,31 +341,31 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds q.append("(%s %s %s)" % (prop, cmp, val)) return " and ".join(q) - def _findText(self, val, args): + def _findText(self, val, args) -> str: val = val.replace("*", "%") args.append("%"+val+"%") args.append("%"+val+"%") return "(n.sfld like ? escape '\\' or n.flds like ? escape '\\')" - def _findNids(self, args): + def _findNids(self, args) -> Optional[str]: (val, args) = args if re.search("[^0-9,]", val): return return "n.id in (%s)" % val - def _findCids(self, args): + def _findCids(self, args) -> Optional[str]: (val, args) = args if re.search("[^0-9,]", val): return return "c.id in (%s)" % val - def _findMid(self, args): + def _findMid(self, args) -> Optional[str]: (val, args) = args if re.search("[^0-9]", val): return return "n.mid = %s" % val - def _findModel(self, args): + def _findModel(self, args) -> str: (val, args) = args ids = [] val = val.lower() @@ -373,7 +374,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds ids.append(m['id']) return "n.mid in %s" % ids2str(ids) - def _findDeck(self, args): + def _findDeck(self, args) -> Optional[str]: # if searching for all decks, skip (val, args) = args if val == "*": @@ -404,7 +405,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds sids = ids2str(ids) return "c.did in %s or c.odid in %s" % (sids, sids) - def _findTemplate(self, args): + def _findTemplate(self, args) -> str: # were we given an ordinal number? (val, args) = args try: @@ -428,7 +429,7 @@ select distinct(n.id) from cards c, notes n where c.nid=n.id and """+preds m['id'], t['ord'])) return " or ".join(lims) - def _findField(self, field, val): + def _findField(self, field, val) -> Optional[str]: field = field.lower() val = val.replace("*", "%") # find models that have that field @@ -460,7 +461,7 @@ where mid in %s and flds like ? escape '\\'""" % ( return "0" return "n.id in %s" % ids2str(nids) - def _findDupes(self, args): + def _findDupes(self, args) -> Optional[str]: # caller must call stripHTMLMedia on passed val (val, args) = args try: @@ -479,7 +480,7 @@ where mid in %s and flds like ? escape '\\'""" % ( # Find and replace ########################################################################## -def findReplace(col, nids, src, dst, regex=False, field=None, fold=True): +def findReplace(col, nids, src, dst, regex=False, field=None, fold=True) -> int: "Find and replace fields in a note." mmap = {} if field: @@ -529,7 +530,7 @@ def findReplace(col, nids, src, dst, regex=False, field=None, fold=True): col.genCards(nids) return len(d) -def fieldNames(col, downcase=True): +def fieldNames(col, downcase=True) -> List: fields = set() for m in col.models.all(): for f in m['flds']: @@ -538,7 +539,7 @@ def fieldNames(col, downcase=True): fields.add(name) return list(fields) -def fieldNamesForNotes(col, nids): +def fieldNamesForNotes(col, nids) -> List: fields = set() mids = col.db.list("select distinct mid from notes where id in %s" % ids2str(nids)) for mid in mids: @@ -551,7 +552,7 @@ def fieldNamesForNotes(col, nids): # Find duplicates ########################################################################## # returns array of ("dupestr", [nids]) -def findDupes(col, fieldName, search=""): +def findDupes(col, fieldName, search="") -> List[Tuple[Any, List]]: # limit search to notes with applicable field name if search: search = "("+search+") " diff --git a/anki/hooks.py b/anki/hooks.py index c8d909e41..98c811a51 100644 --- a/anki/hooks.py +++ b/anki/hooks.py @@ -21,7 +21,7 @@ from typing import Dict, List, Callable, Any _hooks: Dict[str, List[Callable[..., Any]]] = {} -def runHook(hook, *args): +def runHook(hook, *args) -> None: "Run all functions on hook." hook = _hooks.get(hook, None) if hook: @@ -32,7 +32,7 @@ def runHook(hook, *args): hook.remove(func) raise -def runFilter(hook, arg, *args): +def runFilter(hook, arg, *args) -> Any: hook = _hooks.get(hook, None) if hook: for func in hook: @@ -43,14 +43,14 @@ def runFilter(hook, arg, *args): raise return arg -def addHook(hook, func): +def addHook(hook, func) -> None: "Add a function to hook. Ignore if already on hook." if not _hooks.get(hook, None): _hooks[hook] = [] if func not in _hooks[hook]: _hooks[hook].append(func) -def remHook(hook, func): +def remHook(hook, func) -> None: "Remove a function if is on hook." hook = _hooks.get(hook, []) if func in hook: @@ -59,7 +59,7 @@ def remHook(hook, func): # Instrumenting ############################################################################## -def wrap(old, new, pos="after"): +def wrap(old, new, pos="after") -> Callable: "Override an existing function." def repl(*args, **kwargs): if pos == "after": diff --git a/anki/importing/anki2.py b/anki/importing/anki2.py index c03786ac2..31fbee7af 100644 --- a/anki/importing/anki2.py +++ b/anki/importing/anki2.py @@ -8,6 +8,7 @@ from anki.storage import Collection from anki.utils import intTime, splitFields, joinFields from anki.importing.base import Importer from anki.lang import _ +from typing import Any GUID = 1 MID = 2 @@ -27,7 +28,7 @@ class Anki2Importer(Importer): self._decks = {} self.mustResetLearning = False - def run(self, media=None): + def run(self, media=None) -> None: self._prepareFiles() if media is not None: # Anki1 importer has provided us with a custom media folder @@ -37,7 +38,7 @@ class Anki2Importer(Importer): finally: self.src.close(save=False) - def _prepareFiles(self): + def _prepareFiles(self) -> None: importingV2 = self.file.endswith(".anki21") self.mustResetLearning = False @@ -49,7 +50,7 @@ class Anki2Importer(Importer): if self.src.db.scalar("select 1 from cards where queue != 0 limit 1"): self.mustResetLearning = True - def _import(self): + def _import(self) -> None: self._decks = {} if self.deckPrefix: id = self.dst.decks.id(self.deckPrefix) @@ -68,13 +69,13 @@ class Anki2Importer(Importer): # Notes ###################################################################### - def _logNoteRow(self, action, noteRow): + def _logNoteRow(self, action, noteRow) -> None: self.log.append("[%s] %s" % ( action, noteRow[6].replace("\x1f", ", ") )) - def _importNotes(self): + def _importNotes(self) -> None: # build guid -> (id,mod,mid) hash & map of existing note ids self._notes = {} existing = {} @@ -185,7 +186,7 @@ class Anki2Importer(Importer): # determine if note is a duplicate, and adjust mid and/or guid as required # returns true if note should be added - def _uniquifyNote(self, note): + def _uniquifyNote(self, note) -> bool: origGuid = note[GUID] srcMid = note[MID] dstMid = self._mid(srcMid) @@ -207,11 +208,11 @@ class Anki2Importer(Importer): # the schemas don't match, we increment the mid and try again, creating a # new model if necessary. - def _prepareModels(self): + def _prepareModels(self) -> None: "Prepare index of schema hashes." self._modelMap = {} - def _mid(self, srcMid): + def _mid(self, srcMid) -> Any: "Return local id for remote MID." # already processed this mid? if srcMid in self._modelMap: @@ -248,7 +249,7 @@ class Anki2Importer(Importer): # Decks ###################################################################### - def _did(self, did): + def _did(self, did) -> Any: "Given did in src col, return local id." # already converted? if did in self._decks: @@ -295,7 +296,7 @@ class Anki2Importer(Importer): # Cards ###################################################################### - def _importCards(self): + def _importCards(self) -> None: if self.mustResetLearning: self.src.changeSchedulerVer(2) # build map of (guid, ord) -> cid and used id cache @@ -382,7 +383,7 @@ insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)""", revlog) # note: this func only applies to imports of .anki2. for .apkg files, the # apkg importer does the copying - def _importStaticMedia(self): + def _importStaticMedia(self) -> None: # Import any '_foo' prefixed media files regardless of whether # they're used on notes or not dir = self.src.media.dir() @@ -392,7 +393,7 @@ insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)""", revlog) if fname.startswith("_") and not self.dst.media.have(fname): self._writeDstMedia(fname, self._srcMediaData(fname)) - def _mediaData(self, fname, dir=None): + def _mediaData(self, fname, dir=None) -> bytes: if not dir: dir = self.src.media.dir() path = os.path.join(dir, fname) @@ -402,15 +403,15 @@ insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)""", revlog) except (IOError, OSError): return - def _srcMediaData(self, fname): + def _srcMediaData(self, fname) -> bytes: "Data for FNAME in src collection." return self._mediaData(fname, self.src.media.dir()) - def _dstMediaData(self, fname): + def _dstMediaData(self, fname) -> bytes: "Data for FNAME in dst collection." return self._mediaData(fname, self.dst.media.dir()) - def _writeDstMedia(self, fname, data): + def _writeDstMedia(self, fname, data) -> None: path = os.path.join(self.dst.media.dir(), unicodedata.normalize("NFC", fname)) try: @@ -420,7 +421,7 @@ insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)""", revlog) # the user likely used subdirectories pass - def _mungeMedia(self, mid, fields): + def _mungeMedia(self, mid, fields) -> str: fields = splitFields(fields) def repl(match): fname = match.group("fname") @@ -450,7 +451,7 @@ insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)""", revlog) # Post-import cleanup ###################################################################### - def _postImport(self): + def _postImport(self) -> None: for did in list(self._decks.values()): self.col.sched.maybeRandomizeDeck(did) # make sure new position is correct diff --git a/anki/importing/apkg.py b/anki/importing/apkg.py index f0e21e462..88ae921dc 100644 --- a/anki/importing/apkg.py +++ b/anki/importing/apkg.py @@ -7,6 +7,7 @@ import unicodedata import json from anki.utils import tmpfile from anki.importing.anki2 import Anki2Importer +from typing import Any class AnkiPackageImporter(Anki2Importer): @@ -16,7 +17,7 @@ class AnkiPackageImporter(Anki2Importer): self.nameToNum = {} self.zip = None - def run(self): + def run(self) -> None: # extract the deck from the zip file self.zip = z = zipfile.ZipFile(self.file) # v2 scheduler? @@ -52,7 +53,7 @@ class AnkiPackageImporter(Anki2Importer): with open(path, "wb") as f: f.write(z.read(c)) - def _srcMediaData(self, fname): + def _srcMediaData(self, fname) -> Any: if fname in self.nameToNum: return self.zip.read(self.nameToNum[fname]) return None diff --git a/anki/importing/base.py b/anki/importing/base.py index 49f063121..254ba97df 100644 --- a/anki/importing/base.py +++ b/anki/importing/base.py @@ -3,6 +3,7 @@ # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html from anki.utils import maxID +from typing import Any # Base importer ########################################################################## @@ -12,14 +13,14 @@ class Importer: needMapper = False needDelimiter = False - def __init__(self, col, file): + def __init__(self, col, file) -> None: self.file = file self.log = [] self.col = col self.total = 0 self.dst = None - def run(self): + def run(self) -> None: pass # Timestamps @@ -28,9 +29,9 @@ class Importer: # and a previous import may have created timestamps in the future, so we # need to make sure our starting point is safe. - def _prepareTS(self): + def _prepareTS(self) -> None: self._ts = maxID(self.dst.db) - def ts(self): + def ts(self) -> Any: self._ts += 1 return self._ts diff --git a/anki/importing/csvfile.py b/anki/importing/csvfile.py index c80119139..e036f56e0 100644 --- a/anki/importing/csvfile.py +++ b/anki/importing/csvfile.py @@ -7,6 +7,7 @@ import re from anki.importing.noteimp import NoteImporter, ForeignNote from anki.lang import _ +from typing import List class TextImporter(NoteImporter): @@ -22,7 +23,7 @@ class TextImporter(NoteImporter): self.tagsToAdd = [] self.numFields = 0 - def foreignNotes(self): + def foreignNotes(self) -> List[ForeignNote]: self.open() # process all lines log = [] @@ -60,12 +61,12 @@ class TextImporter(NoteImporter): # load & look for the right pattern self.cacheFile() - def cacheFile(self): + def cacheFile(self) -> None: "Read file into self.lines if not already there." if not self.fileobj: self.openFile() - def openFile(self): + def openFile(self) -> None: self.dialect = None self.fileobj = open(self.file, "r", encoding='utf-8-sig') self.data = self.fileobj.read() @@ -81,7 +82,7 @@ class TextImporter(NoteImporter): if not self.dialect and not self.delimiter: raise Exception("unknownFormat") - def updateDelimiter(self): + def updateDelimiter(self) -> None: def err(): raise Exception("unknownFormat") self.dialect = None @@ -126,7 +127,7 @@ class TextImporter(NoteImporter): self.open() return self.numFields - def noteFromFields(self, fields): + def noteFromFields(self, fields) -> ForeignNote: note = ForeignNote() note.fields.extend([x for x in fields]) note.tags.extend(self.tagsToAdd) diff --git a/anki/importing/mnemo.py b/anki/importing/mnemo.py index 294253882..09e159f4e 100644 --- a/anki/importing/mnemo.py +++ b/anki/importing/mnemo.py @@ -101,7 +101,7 @@ acq_reps+ret_reps, lapses, card_type_id from cards"""): def fields(self): return self._fields - def _mungeField(self, fld): + def _mungeField(self, fld) -> str: # \n -> br fld = re.sub("\r?\n", "
", fld) # latex differences @@ -110,7 +110,7 @@ acq_reps+ret_reps, lapses, card_type_id from cards"""): fld = re.sub(")?", "[sound:\\1]", fld) return fld - def _addFronts(self, notes, model=None, fields=("f", "b")): + def _addFronts(self, notes, model=None, fields=("f", "b")) -> None: data = [] for orig in notes: # create a foreign note object @@ -135,7 +135,7 @@ acq_reps+ret_reps, lapses, card_type_id from cards"""): # import self.importNotes(data) - def _addFrontBacks(self, notes): + def _addFrontBacks(self, notes) -> None: m = addBasicModel(self.col) m['name'] = "Mnemosyne-FrontBack" mm = self.col.models @@ -145,7 +145,7 @@ acq_reps+ret_reps, lapses, card_type_id from cards"""): mm.addTemplate(m, t) self._addFronts(notes, m) - def _addVocabulary(self, notes): + def _addVocabulary(self, notes) -> None: mm = self.col.models m = mm.new("Mnemosyne-Vocabulary") for f in "Expression", "Pronunciation", "Meaning", "Notes": @@ -164,7 +164,7 @@ acq_reps+ret_reps, lapses, card_type_id from cards"""): mm.add(m) self._addFronts(notes, m, fields=("f", "p_1", "m_1", "n")) - def _addCloze(self, notes): + def _addCloze(self, notes) -> None: data = [] notes = list(notes.values()) for orig in notes: diff --git a/anki/importing/noteimp.py b/anki/importing/noteimp.py index 176c07340..9418c1870 100644 --- a/anki/importing/noteimp.py +++ b/anki/importing/noteimp.py @@ -12,20 +12,21 @@ from anki.utils import fieldChecksum, guid64, timestampID, \ joinFields, intTime, splitFields from anki.importing.base import Importer from anki.lang import ngettext +from typing import Any, List, Optional # Stores a list of fields, tags and deck ###################################################################### class ForeignNote: "An temporary object storing fields and attributes." - def __init__(self): + def __init__(self) -> None: self.fields = [] self.tags = [] self.deck = None self.cards = {} # map of ord -> card class ForeignCard: - def __init__(self): + def __init__(self) -> None: self.due = 0 self.ivl = 1 self.factor = STARTING_FACTOR @@ -66,11 +67,11 @@ class NoteImporter(Importer): c = self.foreignNotes() self.importNotes(c) - def fields(self): + def fields(self) -> int: "The number of fields." return 0 - def initMapping(self): + def initMapping(self) -> None: flds = [f['name'] for f in self.model['flds']] # truncate to provided count flds = flds[0:self.fields()] @@ -81,18 +82,18 @@ class NoteImporter(Importer): flds = flds + [None] * (self.fields() - len(flds)) self.mapping = flds - def mappingOk(self): + def mappingOk(self) -> bool: return self.model['flds'][0]['name'] in self.mapping - def foreignNotes(self): + def foreignNotes(self) -> List: "Return a list of foreign notes for importing." return [] - def open(self): + def open(self) -> None: "Open file and ensure it's in the right format." return - def importNotes(self, notes): + def importNotes(self, notes) -> None: "Convert each card into a note, apply attributes and add to col." assert self.mappingOk() # note whether tags are mapped @@ -219,7 +220,7 @@ This can happen when you have empty fields or when you have not mapped the \ content in the text file to the correct fields.""")) self.total = len(self._ids) - def newData(self, n): + def newData(self, n) -> Optional[list]: id = self._nextID self._nextID += 1 self._ids.append(id) @@ -233,12 +234,12 @@ content in the text file to the correct fields.""")) intTime(), self.col.usn(), self.col.tags.join(n.tags), n.fieldsStr, "", "", 0, ""] - def addNew(self, rows): + def addNew(self, rows) -> None: self.col.db.executemany( "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", rows) - def updateData(self, n, id, sflds): + def updateData(self, n, id, sflds) -> Optional[list]: self._ids.append(id) if not self.processFields(n, sflds): return @@ -251,7 +252,7 @@ content in the text file to the correct fields.""")) return [intTime(), self.col.usn(), n.fieldsStr, id, n.fieldsStr] - def addUpdates(self, rows): + def addUpdates(self, rows) -> None: old = self.col.db.totalChanges() if self._tagsMapped: self.col.db.executemany(""" @@ -263,7 +264,7 @@ update notes set mod = ?, usn = ?, flds = ? where id = ? and flds != ?""", rows) self.updateCount = self.col.db.totalChanges() - old - def processFields(self, note, fields=None): + def processFields(self, note, fields=None) -> Any: if not fields: fields = [""]*len(self.model['flds']) for c, f in enumerate(self.mapping): @@ -280,7 +281,7 @@ where id = ? and flds != ?""", rows) self._emptyNotes = True return ords - def updateCards(self): + def updateCards(self) -> None: data = [] for nid, ord, c in self._cards: data.append((c.ivl, c.due, c.factor, c.reps, c.lapses, nid, ord)) diff --git a/anki/importing/pauker.py b/anki/importing/pauker.py index d13e4cade..aeb58671b 100644 --- a/anki/importing/pauker.py +++ b/anki/importing/pauker.py @@ -6,6 +6,7 @@ import gzip, math, random, time, html import xml.etree.ElementTree as ET from anki.importing.noteimp import NoteImporter, ForeignNote, ForeignCard from anki.stdmodels import addForwardReverse +from typing import List ONE_DAY = 60*60*24 @@ -28,7 +29,7 @@ class PaukerImporter(NoteImporter): '''Pauker is Front/Back''' return 2 - def foreignNotes(self): + def foreignNotes(self) -> List[ForeignNote]: '''Build and return a list of notes.''' notes = [] @@ -66,7 +67,7 @@ class PaukerImporter(NoteImporter): return notes - def _learnedCard(self, batch, timestamp): + def _learnedCard(self, batch, timestamp) -> ForeignCard: ivl = math.exp(batch) now = time.time() due = ivl - (now - timestamp/1000.0)/ONE_DAY diff --git a/anki/importing/supermemo_xml.py b/anki/importing/supermemo_xml.py index e2d428725..5758885bb 100644 --- a/anki/importing/supermemo_xml.py +++ b/anki/importing/supermemo_xml.py @@ -13,6 +13,7 @@ from anki.lang import ngettext from xml.dom import minidom from string import capwords import re, unicodedata, time +from typing import Any, List class SmartDict(dict): """ @@ -25,7 +26,7 @@ class SmartDict(dict): x.get('first_name'). """ - def __init__(self, *a, **kw): + def __init__(self, *a, **kw) -> None: if a: if isinstance(type(a[0]), dict): kw.update(a[0]) @@ -119,17 +120,17 @@ class SupermemoXmlImporter(NoteImporter): ## TOOLS - def _fudgeText(self, text): + def _fudgeText(self, text) -> Any: "Replace sm syntax to Anki syntax" text = text.replace("\n\r", "
") text = text.replace("\n", "
") return text - def _unicode2ascii(self,str): + def _unicode2ascii(self,str) -> str: "Remove diacritic punctuation from strings (titles)" return "".join([ c for c in unicodedata.normalize('NFKD', str) if not unicodedata.combining(c)]) - def _decode_htmlescapes(self,s): + def _decode_htmlescapes(self,s) -> str: """Unescape HTML code.""" #In case of bad formated html you can import MinimalSoup etc.. see btflsoup source code from bs4 import BeautifulSoup as btflsoup @@ -142,7 +143,7 @@ class SupermemoXmlImporter(NoteImporter): return str(btflsoup(s, "html.parser")) - def _afactor2efactor(self, af): + def _afactor2efactor(self, af) -> Any: # Adapted from # Ranges for A-factors and E-factors @@ -166,7 +167,7 @@ class SupermemoXmlImporter(NoteImporter): ## DEFAULT IMPORTER METHODS - def foreignNotes(self): + def foreignNotes(self) -> List[ForeignNote]: # Load file and parse it by minidom self.loadSource(self.file) @@ -187,7 +188,7 @@ class SupermemoXmlImporter(NoteImporter): ## PARSER METHODS - def addItemToCards(self,item): + def addItemToCards(self,item) -> None: "This method actually do conversion" # new anki card @@ -247,7 +248,7 @@ class SupermemoXmlImporter(NoteImporter): self.notes.append(note) - def logger(self,text,level=1): + def logger(self,text,level=1) -> None: "Wrapper for Anki logger" dLevels={0:'',1:'Info',2:'Verbose',3:'Debug'} @@ -259,7 +260,7 @@ class SupermemoXmlImporter(NoteImporter): # OPEN AND LOAD - def openAnything(self,source): + def openAnything(self,source) -> Any: "Open any source / actually only openig of files is used" if source == "-": @@ -282,7 +283,7 @@ class SupermemoXmlImporter(NoteImporter): import io return io.StringIO(str(source)) - def loadSource(self, source): + def loadSource(self, source) -> None: """Load source file and parse with xml.dom.minidom""" self.source = source self.logger('Load started...') @@ -293,7 +294,7 @@ class SupermemoXmlImporter(NoteImporter): # PARSE - def parse(self, node=None): + def parse(self, node=None) -> None: "Parse method - parses document elements" if node is None and self.xmldoc is not None: @@ -306,12 +307,12 @@ class SupermemoXmlImporter(NoteImporter): else: self.logger('No handler for method %s' % _method, level=3) - def parse_Document(self, node): + def parse_Document(self, node) -> None: "Parse XML document" self.parse(node.documentElement) - def parse_Element(self, node): + def parse_Element(self, node) -> None: "Parse XML element" _method = "do_%s" % node.tagName @@ -322,7 +323,7 @@ class SupermemoXmlImporter(NoteImporter): self.logger('No handler for method %s' % _method, level=3) #print traceback.print_exc() - def parse_Text(self, node): + def parse_Text(self, node) -> None: "Parse text inside elements. Text is stored into local buffer." text = node.data @@ -336,12 +337,12 @@ class SupermemoXmlImporter(NoteImporter): # DO - def do_SuperMemoCollection(self, node): + def do_SuperMemoCollection(self, node) -> None: "Process SM Collection" for child in node.childNodes: self.parse(child) - def do_SuperMemoElement(self, node): + def do_SuperMemoElement(self, node) -> None: "Process SM Element (Type - Title,Topics)" self.logger('='*45, level=3) @@ -391,14 +392,14 @@ class SupermemoXmlImporter(NoteImporter): t = self.cntMeta['title'].pop() self.logger('End of topic \t- %s' % (t), level=2) - def do_Content(self, node): + def do_Content(self, node) -> None: "Process SM element Content" for child in node.childNodes: if hasattr(child,'tagName') and child.firstChild is not None: self.cntElm[-1][child.tagName]=child.firstChild.data - def do_LearningData(self, node): + def do_LearningData(self, node) -> None: "Process SM element LearningData" for child in node.childNodes: @@ -415,7 +416,7 @@ class SupermemoXmlImporter(NoteImporter): # for child in node.childNodes: self.parse(child) # self.cntElm[-1][node.tagName]=self.cntBuf.pop() - def do_Title(self, node): + def do_Title(self, node) -> None: "Process SM element Title" t = self._decode_htmlescapes(node.firstChild.data) @@ -425,7 +426,7 @@ class SupermemoXmlImporter(NoteImporter): self.logger('Start of topic \t- ' + " / ".join(self.cntMeta['title']), level=2) - def do_Type(self, node): + def do_Type(self, node) -> None: "Process SM element Type" if len(self.cntBuf) >=1 : diff --git a/anki/lang.py b/anki/lang.py index 39d7fbc45..215802a58 100644 --- a/anki/lang.py +++ b/anki/lang.py @@ -5,6 +5,7 @@ import os, sys, re import gettext import threading +from typing import Any langs = sorted([ ("Afrikaans", "af_ZA"), @@ -108,20 +109,20 @@ threadLocal = threading.local() currentLang = None currentTranslation = None -def localTranslation(): +def localTranslation() -> Any: "Return the translation local to this thread, or the default." if getattr(threadLocal, 'currentTranslation', None): return threadLocal.currentTranslation else: return currentTranslation -def _(str): +def _(str) -> Any: return localTranslation().gettext(str) -def ngettext(single, plural, n): +def ngettext(single, plural, n) -> Any: return localTranslation().ngettext(single, plural, n) -def langDir(): +def langDir() -> str: from anki.utils import isMac filedir = os.path.dirname(os.path.abspath(__file__)) if isMac: @@ -134,7 +135,7 @@ def langDir(): dir = os.path.abspath(os.path.join(filedir, "..", "locale")) return dir -def setLang(lang, local=True): +def setLang(lang, local=True) -> None: lang = mungeCode(lang) trans = gettext.translation( 'anki', langDir(), languages=[lang], fallback=True) @@ -146,18 +147,18 @@ def setLang(lang, local=True): currentLang = lang currentTranslation = trans -def getLang(): +def getLang() -> Any: "Return the language local to this thread, or the default." if getattr(threadLocal, 'currentLang', None): return threadLocal.currentLang else: return currentLang -def noHint(str): +def noHint(str) -> str: "Remove translation hint from end of string." return re.sub(r"(^.*?)( ?\(.+?\))?$", "\\1", str) -def mungeCode(code): +def mungeCode(code) -> Any: code = code.replace("-", "_") if code in compatMap: code = compatMap[code] diff --git a/anki/latex.py b/anki/latex.py index b60b0f170..2a5b001ff 100644 --- a/anki/latex.py +++ b/anki/latex.py @@ -6,6 +6,7 @@ import re, os, shutil, html from anki.utils import checksum, call, namedtmp, tmpdir, isMac, stripHTML from anki.hooks import addHook from anki.lang import _ +from typing import Any pngCommands = [ ["latex", "-interaction=nonstopmode", "tmp.tex"], @@ -28,7 +29,7 @@ regexps = { if isMac: os.environ['PATH'] += ":/usr/texbin:/Library/TeX/texbin" -def stripLatex(text): +def stripLatex(text) -> Any: for match in regexps['standard'].finditer(text): text = text.replace(match.group(), "") for match in regexps['expression'].finditer(text): @@ -37,7 +38,7 @@ def stripLatex(text): text = text.replace(match.group(), "") return text -def mungeQA(html, type, fields, model, data, col): +def mungeQA(html, type, fields, model, data, col) -> Any: "Convert TEXT with embedded latex tags to image links." for match in regexps['standard'].finditer(html): html = html.replace(match.group(), _imgLink(col, match.group(1), model)) @@ -50,7 +51,7 @@ def mungeQA(html, type, fields, model, data, col): "\\begin{displaymath}" + match.group(1) + "\\end{displaymath}", model)) return html -def _imgLink(col, latex, model): +def _imgLink(col, latex, model) -> Any: "Return an img link for LATEX, creating if necesssary." txt = _latexFromHtml(col, latex) @@ -75,13 +76,13 @@ def _imgLink(col, latex, model): else: return link -def _latexFromHtml(col, latex): +def _latexFromHtml(col, latex) -> Any: "Convert entities and fix newlines." latex = re.sub("|
", "\n", latex) latex = stripHTML(latex) return latex -def _buildImg(col, latex, fname, model): +def _buildImg(col, latex, fname, model) -> Any: # add header/footer latex = (model["latexPre"] + "\n" + latex + "\n" + @@ -129,7 +130,7 @@ package in the LaTeX header instead.""") % bad os.chdir(oldcwd) log.close() -def _errMsg(type, texpath): +def _errMsg(type, texpath) -> Any: msg = (_("Error executing %s.") % type) + "
" msg += (_("Generated file: %s") % texpath) + "
" try: diff --git a/anki/media.py b/anki/media.py index 42654b448..9fcaf674e 100644 --- a/anki/media.py +++ b/anki/media.py @@ -17,6 +17,9 @@ from anki.db import DB, DBError from anki.consts import * from anki.latex import mungeQA from anki.lang import _ +from typing import Any, List, Optional, Tuple, TypeVar, Union + +_T0 = TypeVar('_T0') class MediaManager: @@ -29,7 +32,7 @@ class MediaManager: ] regexps = soundRegexps + imgRegexps - def __init__(self, col, server): + def __init__(self, col, server) -> None: self.col = col if server: self._dir = None @@ -50,7 +53,7 @@ class MediaManager: # change database self.connect() - def connect(self): + def connect(self) -> None: if self.col.server: return path = self.dir()+".db2" @@ -61,7 +64,7 @@ class MediaManager: self._initDB() self.maybeUpgrade() - def _initDB(self): + def _initDB(self) -> None: self.db.executescript(""" create table media ( fname text not null primary key, @@ -75,7 +78,7 @@ create index idx_media_dirty on media (dirty); create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); """) - def maybeUpgrade(self): + def maybeUpgrade(self) -> None: oldpath = self.dir()+".db" if os.path.exists(oldpath): self.db.execute('attach "../collection.media.db" as old') @@ -102,7 +105,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); os.unlink(npath) os.rename("../collection.media.db", npath) - def close(self): + def close(self) -> None: if self.col.server: return self.db.close() @@ -115,16 +118,16 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # may have been deleted pass - def _deleteDB(self): + def _deleteDB(self) -> None: path = self.db._path self.close() os.unlink(path) self.connect() - def dir(self): + def dir(self) -> Any: return self._dir - def _isFAT32(self): + def _isFAT32(self) -> Optional[bool]: if not isWin: return # pylint: disable=import-error @@ -141,11 +144,11 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); ########################################################################## # opath must be in unicode - def addFile(self, opath): + def addFile(self, opath) -> Any: with open(opath, "rb") as f: return self.writeData(opath, f.read()) - def writeData(self, opath, data, typeHint=None): + def writeData(self, opath, data, typeHint=None) -> Any: # if fname is a full path, use only the basename fname = os.path.basename(opath) @@ -193,7 +196,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # String manipulation ########################################################################## - def filesInStr(self, mid, string, includeRemote=False): + def filesInStr(self, mid, string, includeRemote=False) -> List[str]: l = [] model = self.col.models.get(mid) strings = [] @@ -215,7 +218,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); l.append(fname) return l - def _expandClozes(self, string): + def _expandClozes(self, string) -> List[str]: ords = set(re.findall(r"{{c(\d+)::.+?}}", string)) strings = [] from anki.template.template import clozeReg @@ -233,17 +236,17 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); strings.append(re.sub(clozeReg%".+?", arepl, string)) return strings - def transformNames(self, txt, func): + def transformNames(self, txt, func) -> Any: for reg in self.regexps: txt = re.sub(reg, func, txt) return txt - def strip(self, txt): + def strip(self, txt: _T0) -> Union[str, _T0]: for reg in self.regexps: txt = re.sub(reg, "", txt) return txt - def escapeImages(self, string, unescape=False): + def escapeImages(self, string: _T0, unescape=False) -> Union[str, _T0]: if unescape: fn = urllib.parse.unquote else: @@ -261,7 +264,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # Rebuilding DB ########################################################################## - def check(self, local=None): + def check(self, local=None) -> Any: "Return (missingFiles, unusedFiles)." mdir = self.dir() # gather all media references in NFC form @@ -335,7 +338,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); _("Anki does not support files in subfolders of the collection.media folder.")) return (nohave, unused, warnings) - def _normalizeNoteRefs(self, nid): + def _normalizeNoteRefs(self, nid) -> None: note = self.col.getNote(nid) for c, fld in enumerate(note.fields): nfc = unicodedata.normalize("NFC", fld) @@ -346,7 +349,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # Copying on import ########################################################################## - def have(self, fname): + def have(self, fname) -> bool: return os.path.exists(os.path.join(self.dir(), fname)) # Illegal characters and paths @@ -354,7 +357,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); _illegalCharReg = re.compile(r'[][><:"/?*^\\|\0\r\n]') - def stripIllegal(self, str): + def stripIllegal(self, str) -> str: return re.sub(self._illegalCharReg, "", str) def hasIllegal(self, s: str): @@ -366,7 +369,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); return True return False - def cleanFilename(self, fname): + def cleanFilename(self, fname) -> str: fname = self.stripIllegal(fname) fname = self._cleanWin32Filename(fname) fname = self._cleanLongFilename(fname) @@ -375,7 +378,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); return fname - def _cleanWin32Filename(self, fname): + def _cleanWin32Filename(self, fname: _T0) -> Union[str, _T0]: if not isWin: return fname @@ -387,7 +390,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); return fname - def _cleanLongFilename(self, fname): + def _cleanLongFilename(self, fname) -> Any: # a fairly safe limit that should work on typical windows # paths and on eCryptfs partitions, even with a duplicate # suffix appended @@ -416,22 +419,22 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # Tracking changes ########################################################################## - def findChanges(self): + def findChanges(self) -> None: "Scan the media folder if it's changed, and note any changes." if self._changed(): self._logChanges() - def haveDirty(self): + def haveDirty(self) -> Any: return self.db.scalar("select 1 from media where dirty=1 limit 1") - def _mtime(self, path): + def _mtime(self, path) -> int: return int(os.stat(path).st_mtime) - def _checksum(self, path): + def _checksum(self, path) -> str: with open(path, "rb") as f: return checksum(f.read()) - def _changed(self): + def _changed(self) -> int: "Return dir mtime if it has changed since the last findChanges()" # doesn't track edits, but user can add or remove a file to update mod = self.db.scalar("select dirMod from meta") @@ -440,7 +443,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); return False return mtime - def _logChanges(self): + def _logChanges(self) -> None: (added, removed) = self._changes() media = [] for f, mtime in added: @@ -453,7 +456,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); self.db.execute("update meta set dirMod = ?", self._mtime(self.dir())) self.db.commit() - def _changes(self): + def _changes(self) -> Tuple[List[Tuple[str, int]], List[str]]: self.cache = {} for (name, csum, mod) in self.db.execute( "select fname, csum, mtime from media where csum is not null"): @@ -515,37 +518,37 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # Syncing-related ########################################################################## - def lastUsn(self): + def lastUsn(self) -> Any: return self.db.scalar("select lastUsn from meta") - def setLastUsn(self, usn): + def setLastUsn(self, usn) -> None: self.db.execute("update meta set lastUsn = ?", usn) self.db.commit() - def syncInfo(self, fname): + def syncInfo(self, fname) -> Any: ret = self.db.first( "select csum, dirty from media where fname=?", fname) return ret or (None, 0) - def markClean(self, fnames): + def markClean(self, fnames) -> None: for fname in fnames: self.db.execute( "update media set dirty=0 where fname=?", fname) - def syncDelete(self, fname): + def syncDelete(self, fname) -> None: if os.path.exists(fname): os.unlink(fname) self.db.execute("delete from media where fname=?", fname) - def mediaCount(self): + def mediaCount(self) -> Any: return self.db.scalar( "select count() from media where csum is not null") - def dirtyCount(self): + def dirtyCount(self) -> Any: return self.db.scalar( "select count() from media where dirty=1") - def forceResync(self): + def forceResync(self) -> None: self.db.execute("delete from media") self.db.execute("update meta set lastUsn=0,dirMod=0") self.db.commit() @@ -557,7 +560,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); # Media syncing: zips ########################################################################## - def mediaChangesZip(self): + def mediaChangesZip(self) -> Tuple[bytes, list]: f = io.BytesIO() z = zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED) @@ -590,7 +593,7 @@ create table meta (dirMod int, lastUsn int); insert into meta values (0, 0); z.close() return f.getvalue(), fnames - def addFilesFromZip(self, zipData): + def addFilesFromZip(self, zipData) -> int: "Extract zip data; true if finished." f = io.BytesIO(zipData) z = zipfile.ZipFile(f, "r") diff --git a/anki/models.py b/anki/models.py index 2bfc0d47b..005198324 100644 --- a/anki/models.py +++ b/anki/models.py @@ -10,6 +10,7 @@ from anki.lang import _ from anki.consts import * from anki.hooks import runHook import time +from typing import List, Optional, Tuple, Union # Models ########################################################################## @@ -75,17 +76,17 @@ class ModelManager: # Saving/loading registry ############################################################# - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.models = {} self.changed = False - def load(self, json_): + def load(self, json_) -> None: "Load registry from JSON." self.changed = False self.models = json.loads(json_) - def save(self, m=None, templates=False, updateReqs=True): + def save(self, m=None, templates=False, updateReqs=True) -> None: "Mark M modified if provided, and schedule registry flush." if m and m['id']: m['mod'] = intTime() @@ -97,7 +98,7 @@ class ModelManager: self.changed = True runHook("newModel") - def flush(self): + def flush(self) -> None: "Flush the registry if any models were changed." if self.changed: self.ensureNotEmpty() @@ -105,7 +106,7 @@ class ModelManager: json.dumps(self.models)) self.changed = False - def ensureNotEmpty(self): + def ensureNotEmpty(self) -> Optional[bool]: if not self.models: from anki.stdmodels import addBasicModel addBasicModel(self.col) @@ -114,37 +115,37 @@ class ModelManager: # Retrieving and creating models ############################################################# - def current(self, forDeck=True): + def current(self, forDeck=True) -> Any: "Get current model." m = self.get(self.col.decks.current().get('mid')) if not forDeck or not m: m = self.get(self.col.conf['curModel']) return m or list(self.models.values())[0] - def setCurrent(self, m): + def setCurrent(self, m) -> None: self.col.conf['curModel'] = m['id'] self.col.setMod() - def get(self, id): + def get(self, id) -> Any: "Get model with ID, or None." id = str(id) if id in self.models: return self.models[id] - def all(self): + def all(self) -> List: "Get all models." return list(self.models.values()) - def allNames(self): + def allNames(self) -> List: return [m['name'] for m in self.all()] - def byName(self, name): + def byName(self, name) -> Any: "Get model with NAME." for m in list(self.models.values()): if m['name'] == name: return m - def new(self, name): + def new(self, name: str) -> Dict[str, Any]: "Create a new model, save it in the registry, and return it." # caller should call save() after modifying m = defaultModel.copy() @@ -156,7 +157,7 @@ class ModelManager: m['id'] = None return m - def rem(self, m): + def rem(self, m) -> None: "Delete model, and all its cards/notes." self.col.modSchema(check=True) current = self.current()['id'] == m['id'] @@ -171,52 +172,52 @@ select id from cards where nid in (select id from notes where mid = ?)""", if current: self.setCurrent(list(self.models.values())[0]) - def add(self, m): + def add(self, m) -> None: self._setID(m) self.update(m) self.setCurrent(m) self.save(m) - def ensureNameUnique(self, m): + def ensureNameUnique(self, m) -> None: for mcur in self.all(): if (mcur['name'] == m['name'] and mcur['id'] != m['id']): m['name'] += "-" + checksum(str(time.time()))[:5] break - def update(self, m): + def update(self, m) -> None: "Add or update an existing model. Used for syncing and merging." self.ensureNameUnique(m) self.models[str(m['id'])] = m # mark registry changed, but don't bump mod time self.save() - def _setID(self, m): + def _setID(self, m) -> None: while 1: id = str(intTime(1000)) if id not in self.models: break m['id'] = id - def have(self, id): + def have(self, id) -> bool: return str(id) in self.models - def ids(self): + def ids(self) -> List[str]: return list(self.models.keys()) # Tools ################################################## - def nids(self, m): + def nids(self, m) -> Any: "Note ids for M." return self.col.db.list( "select id from notes where mid = ?", m['id']) - def useCount(self, m): + def useCount(self, m) -> Any: "Number of note using M." return self.col.db.scalar( "select count() from notes where mid = ?", m['id']) - def tmplUseCount(self, m, ord): + def tmplUseCount(self, m, ord) -> Any: return self.col.db.scalar(""" select count() from cards, notes where cards.nid = notes.id and notes.mid = ? and cards.ord = ?""", m['id'], ord) @@ -224,7 +225,7 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) # Copying ################################################## - def copy(self, m): + def copy(self, m) -> Any: "Copy, save and return." m2 = copy.deepcopy(m) m2['name'] = _("%s copy") % m2['name'] @@ -234,30 +235,30 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) # Fields ################################################## - def newField(self, name): + def newField(self, name) -> Dict[str, Any]: assert(isinstance(name, str)) f = defaultField.copy() f['name'] = name return f - def fieldMap(self, m): + def fieldMap(self, m) -> Dict[Any, Tuple[Any, Any]]: "Mapping of field name -> (ord, field)." return dict((f['name'], (f['ord'], f)) for f in m['flds']) - def fieldNames(self, m): + def fieldNames(self, m) -> List: return [f['name'] for f in m['flds']] - def sortIdx(self, m): + def sortIdx(self, m) -> Any: return m['sortf'] - def setSortIdx(self, m, idx): + def setSortIdx(self, m, idx) -> None: assert 0 <= idx < len(m['flds']) self.col.modSchema(check=True) m['sortf'] = idx self.col.updateFieldCache(self.nids(m)) self.save(m, updateReqs=False) - def addField(self, m, field): + def addField(self, m, field) -> None: # only mod schema if model isn't new if m['id']: self.col.modSchema(check=True) @@ -269,7 +270,7 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) return fields self._transformFields(m, add) - def remField(self, m, field): + def remField(self, m, field) -> None: self.col.modSchema(check=True) # save old sort field sortFldName = m['flds'][m['sortf']]['name'] @@ -292,7 +293,7 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) # saves self.renameField(m, field, None) - def moveField(self, m, field, idx): + def moveField(self, m, field, idx) -> None: self.col.modSchema(check=True) oldidx = m['flds'].index(field) if oldidx == idx: @@ -313,7 +314,7 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) return fields self._transformFields(m, move) - def renameField(self, m, field, newName): + def renameField(self, m, field, newName) -> None: self.col.modSchema(check=True) pat = r'{{([^{}]*)([:#^/]|[^:#/^}][^:}]*?:|)%s}}' def wrap(txt): @@ -331,11 +332,11 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) field['name'] = newName self.save(m) - def _updateFieldOrds(self, m): + def _updateFieldOrds(self, m) -> None: for c, f in enumerate(m['flds']): f['ord'] = c - def _transformFields(self, m, fn): + def _transformFields(self, m, fn) -> None: # model hasn't been added yet? if not m['id']: return @@ -350,12 +351,12 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) # Templates ################################################## - def newTemplate(self, name): + def newTemplate(self, name: str) -> Dict[str, Any]: t = defaultTemplate.copy() t['name'] = name return t - def addTemplate(self, m, template): + def addTemplate(self, m, template) -> None: "Note: should col.genCards() afterwards." if m['id']: self.col.modSchema(check=True) @@ -363,7 +364,7 @@ and notes.mid = ? and cards.ord = ?""", m['id'], ord) self._updateTemplOrds(m) self.save(m) - def remTemplate(self, m, template): + def remTemplate(self, m, template) -> bool: "False if removing template would leave orphan notes." assert len(m['tmpls']) > 1 # find cards using this template @@ -393,11 +394,11 @@ update cards set ord = ord - 1, usn = ?, mod = ? self.save(m) return True - def _updateTemplOrds(self, m): + def _updateTemplOrds(self, m) -> None: for c, t in enumerate(m['tmpls']): t['ord'] = c - def moveTemplate(self, m, template, idx): + def moveTemplate(self, m, template, idx) -> None: oldidx = m['tmpls'].index(template) if oldidx == idx: return @@ -416,7 +417,7 @@ update cards set ord = (case %s end),usn=?,mod=? where nid in ( select id from notes where mid = ?)""" % " ".join(map), self.col.usn(), intTime(), m['id']) - def _syncTemplates(self, m): + def _syncTemplates(self, m) -> None: rem = self.col.genCards(self.nids(m)) # Model changing @@ -424,7 +425,7 @@ select id from notes where mid = ?)""" % " ".join(map), # - maps are ord->ord, and there should not be duplicate targets # - newModel should be self if model is not changing - def change(self, m, nids, newModel, fmap, cmap): + def change(self, m, nids, newModel, fmap, cmap) -> None: self.col.modSchema(check=True) assert newModel['id'] == m['id'] or (fmap and cmap) if fmap: @@ -433,7 +434,7 @@ select id from notes where mid = ?)""" % " ".join(map), self._changeCards(nids, m, newModel, cmap) self.col.genCards(nids) - def _changeNotes(self, nids, newModel, map): + def _changeNotes(self, nids, newModel, map) -> None: d = [] nfields = len(newModel['flds']) for (nid, flds) in self.col.db.execute( @@ -452,7 +453,7 @@ select id from notes where mid = ?)""" % " ".join(map), "update notes set flds=:flds,mid=:mid,mod=:m,usn=:u where id = :nid", d) self.col.updateFieldCache(nids) - def _changeCards(self, nids, oldModel, newModel, map): + def _changeCards(self, nids, oldModel, newModel, map) -> None: d = [] deleted = [] for (cid, ord) in self.col.db.execute( @@ -482,7 +483,7 @@ select id from notes where mid = ?)""" % " ".join(map), # Schema hash ########################################################################## - def scmhash(self, m): + def scmhash(self, m) -> str: "Return a hash of the schema, to see if models are compatible." s = "" for f in m['flds']: @@ -494,7 +495,7 @@ select id from notes where mid = ?)""" % " ".join(map), # Required field/text cache ########################################################################## - def _updateRequired(self, m): + def _updateRequired(self, m) -> None: if m['type'] == MODEL_CLOZE: # nothing to do return @@ -505,7 +506,7 @@ select id from notes where mid = ?)""" % " ".join(map), req.append([t['ord'], ret[0], ret[1]]) m['req'] = req - def _reqForTemplate(self, m, flds, t): + def _reqForTemplate(self, m, flds, t) -> Tuple[Union[str, List[int]], ...]: a = [] b = [] for f in flds: @@ -542,7 +543,7 @@ select id from notes where mid = ?)""" % " ".join(map), req.append(i) return type, req - def availOrds(self, m, flds): + def availOrds(self, m, flds) -> List: "Given a joined field string, return available template ordinals." if m['type'] == MODEL_CLOZE: return self._availClozeOrds(m, flds) @@ -576,7 +577,7 @@ select id from notes where mid = ?)""" % " ".join(map), avail.append(ord) return avail - def _availClozeOrds(self, m, flds, allowEmpty=True): + def _availClozeOrds(self, m, flds, allowEmpty=True) -> List: sflds = splitFields(flds) map = self.fieldMap(m) ords = set() @@ -598,7 +599,7 @@ select id from notes where mid = ?)""" % " ".join(map), # Sync handling ########################################################################## - def beforeUpload(self): + def beforeUpload(self) -> None: for m in self.all(): m['usn'] = 0 self.save() diff --git a/anki/mpv.py b/anki/mpv.py index 13267deb0..c111f1343 100644 --- a/anki/mpv.py +++ b/anki/mpv.py @@ -56,6 +56,7 @@ class MPVTimeoutError(MPVError): pass from anki.utils import isWin +from typing import Any if isWin: # pylint: disable=import-error import win32file, win32pipe, pywintypes, winerror # pytype: disable=import-error @@ -76,7 +77,7 @@ class MPVBase: "--keep-open=no", ] - def __init__(self, window_id=None, debug=False): + def __init__(self, window_id=None, debug=False) -> None: self.window_id = window_id self.debug = debug @@ -87,18 +88,18 @@ class MPVBase: self._prepare_thread() self._start_thread() - def __del__(self): + def __del__(self) -> None: self._stop_thread() self._stop_process() self._stop_socket() - def _thread_id(self): + def _thread_id(self) -> int: return threading.get_ident() # # Process # - def _prepare_process(self): + def _prepare_process(self) -> None: """Prepare the argument list for the mpv process. """ self.argv = [self.executable] @@ -107,12 +108,12 @@ class MPVBase: if self.window_id is not None: self.argv += ["--wid", str(self.window_id)] - def _start_process(self): + def _start_process(self) -> None: """Start the mpv process. """ self._proc = subprocess.Popen(self.argv, env=self.popenEnv) - def _stop_process(self): + def _stop_process(self) -> None: """Stop the mpv process. """ if hasattr(self, "_proc"): @@ -125,7 +126,7 @@ class MPVBase: # # Socket communication # - def _prepare_socket(self): + def _prepare_socket(self) -> None: """Create a random socket filename which we pass to mpv with the --input-unix-socket option. """ @@ -136,7 +137,7 @@ class MPVBase: os.close(fd) os.remove(self._sock_filename) - def _start_socket(self): + def _start_socket(self) -> None: """Wait for the mpv process to create the unix socket and finish startup. """ @@ -173,7 +174,7 @@ class MPVBase: else: raise MPVProcessError("unable to start process") - def _stop_socket(self): + def _stop_socket(self) -> None: """Clean up the socket. """ if hasattr(self, "_sock"): @@ -184,7 +185,7 @@ class MPVBase: except OSError: pass - def _prepare_thread(self): + def _prepare_thread(self) -> None: """Set up the queues for the communication threads. """ self._request_queue = Queue(1) @@ -192,14 +193,14 @@ class MPVBase: self._event_queue = Queue() self._stop_event = threading.Event() - def _start_thread(self): + def _start_thread(self) -> None: """Start up the communication threads. """ self._thread = threading.Thread(target=self._reader) self._thread.daemon = True self._thread.start() - def _stop_thread(self): + def _stop_thread(self) -> None: """Stop the communication threads. """ if hasattr(self, "_stop_event"): @@ -207,7 +208,7 @@ class MPVBase: if hasattr(self, "_thread"): self._thread.join() - def _reader(self): + def _reader(self) -> None: """Read the incoming json messages from the unix socket that is connected to the mpv process. Pass them on to the message handler. """ @@ -249,21 +250,21 @@ class MPVBase: # # Message handling # - def _compose_message(self, message): + def _compose_message(self, message) -> bytes: """Return a json representation from a message dictionary. """ # XXX may be strict is too strict ;-) data = json.dumps(message) return data.encode("utf8", "strict") + b"\n" - def _parse_message(self, data): + def _parse_message(self, data) -> Any: """Return a message dictionary from a json representation. """ # XXX may be strict is too strict ;-) data = data.decode("utf8", "strict") return json.loads(data) - def _handle_message(self, message): + def _handle_message(self, message) -> None: """Handle different types of incoming messages, i.e. responses to commands or asynchronous events. """ @@ -283,7 +284,7 @@ class MPVBase: else: raise MPVCommunicationError("invalid message %r" % message) - def _send_message(self, message, timeout=None): + def _send_message(self, message, timeout=None) -> None: """Send a message/command to the mpv process, message must be a dictionary of the form {"command": ["arg1", "arg2", ...]}. Responses from the mpv process must be collected using _get_response(). @@ -320,7 +321,7 @@ class MPVBase: raise MPVCommunicationError("broken sender socket") data = data[size:] - def _get_response(self, timeout=None): + def _get_response(self, timeout=None) -> Any: """Collect the response message to a previous request. If there was an error a MPVCommandError exception is raised, otherwise the command specific data is returned. @@ -335,7 +336,7 @@ class MPVBase: else: return message.get("data") - def _get_event(self, timeout=None): + def _get_event(self, timeout=None) -> Any: """Collect a single event message that has been received out-of-band from the mpv process. If a timeout is specified and there have not been any events during that period, None is returned. @@ -345,7 +346,7 @@ class MPVBase: except Empty: return None - def _send_request(self, message, timeout=None, _retry=1): + def _send_request(self, message, timeout=None, _retry=1) -> Any: """Send a command to the mpv process and collect the result. """ self.ensure_running() @@ -365,12 +366,12 @@ class MPVBase: # # Public API # - def is_running(self): + def is_running(self) -> bool: """Return True if the mpv process is still active. """ return self._proc.poll() is None - def ensure_running(self): + def ensure_running(self) -> None: if not self.is_running(): self._stop_thread() self._stop_process() @@ -382,7 +383,7 @@ class MPVBase: self._prepare_thread() self._start_thread() - def close(self): + def close(self) -> None: """Shutdown the mpv process and our communication setup. """ if self.is_running(): @@ -413,7 +414,7 @@ class MPV(MPVBase): threads to the same MPV instance are synchronized. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._callbacks = {} @@ -463,7 +464,7 @@ class MPV(MPVBase): # # Event/callback API # - def _event_reader(self): + def _event_reader(self) -> None: """Collect incoming event messages and call the event handler. """ while not self._stop_event.is_set(): @@ -473,7 +474,7 @@ class MPV(MPVBase): self._handle_event(message) - def _handle_event(self, message): + def _handle_event(self, message) -> None: """Lookup and call the callbacks for a particular event message. """ if message["event"] == "property-change": @@ -487,7 +488,7 @@ class MPV(MPVBase): else: callback() - def register_callback(self, name, callback): + def register_callback(self, name, callback) -> None: """Register a function `callback` for the event `name`. """ try: @@ -497,7 +498,7 @@ class MPV(MPVBase): self._callbacks.setdefault(name, []).append(callback) - def unregister_callback(self, name, callback): + def unregister_callback(self, name, callback) -> None: """Unregister a previously registered function `callback` for the event `name`. """ @@ -511,7 +512,7 @@ class MPV(MPVBase): except ValueError: raise MPVError("callback %r not registered for event %r" % (callback, name)) - def register_property_callback(self, name, callback): + def register_property_callback(self, name, callback) -> int: """Register a function `callback` for the property-change event on property `name`. """ @@ -533,7 +534,7 @@ class MPV(MPVBase): self._property_serials[(name, callback)] = serial return serial - def unregister_property_callback(self, name, callback): + def unregister_property_callback(self, name, callback) -> None: """Unregister a previously registered function `callback` for the property-change event on property `name`. """ @@ -553,17 +554,17 @@ class MPV(MPVBase): # # Public API # - def command(self, *args, timeout=1): + def command(self, *args, timeout=1) -> Any: """Execute a single command on the mpv process and return the result. """ return self._send_request({"command": list(args)}, timeout=timeout) - def get_property(self, name): + def get_property(self, name) -> Any: """Return the value of property `name`. """ return self.command("get_property", name) - def set_property(self, name, value): + def set_property(self, name, value) -> Any: """Set the value of property `name`. """ return self.command("set_property", name, value) diff --git a/anki/notes.py b/anki/notes.py index d83c8f55f..b1bd5953e 100644 --- a/anki/notes.py +++ b/anki/notes.py @@ -4,10 +4,11 @@ from anki.utils import fieldChecksum, intTime, \ joinFields, splitFields, stripHTMLMedia, timestampID, guid64 +from typing import Any, List, Tuple class Note: - def __init__(self, col, model=None, id=None): + def __init__(self, col, model=None, id=None) -> None: assert not (model and id) self.col = col self.newlyAdded = False @@ -26,7 +27,7 @@ class Note: self._fmap = self.col.models.fieldMap(self._model) self.scm = self.col.scm - def load(self): + def load(self) -> None: (self.guid, self.mid, self.mod, @@ -43,7 +44,7 @@ from notes where id = ?""", self.id) self._fmap = self.col.models.fieldMap(self._model) self.scm = self.col.scm - def flush(self, mod=None): + def flush(self, mod=None) -> None: "If fields or tags have changed, write changes to disk." assert self.scm == self.col.scm self._preFlush() @@ -66,57 +67,57 @@ insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)""", self.col.tags.register(self.tags) self._postFlush() - def joinedFields(self): + def joinedFields(self) -> str: return joinFields(self.fields) - def cards(self): + def cards(self) -> List: return [self.col.getCard(id) for id in self.col.db.list( "select id from cards where nid = ? order by ord", self.id)] - def model(self): + def model(self) -> Any: return self._model # Dict interface ################################################## - def keys(self): + def keys(self) -> List: return list(self._fmap.keys()) - def values(self): + def values(self) -> Any: return self.fields - def items(self): + def items(self) -> List[Tuple[Any, Any]]: return [(f['name'], self.fields[ord]) for ord, f in sorted(self._fmap.values())] - def _fieldOrd(self, key): + def _fieldOrd(self, key) -> Any: try: return self._fmap[key][0] except: raise KeyError(key) - def __getitem__(self, key): + def __getitem__(self, key) -> Any: return self.fields[self._fieldOrd(key)] - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: self.fields[self._fieldOrd(key)] = value - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in list(self._fmap.keys()) # Tags ################################################## - def hasTag(self, tag): + def hasTag(self, tag) -> Any: return self.col.tags.inList(tag, self.tags) - def stringTags(self): + def stringTags(self) -> Any: return self.col.tags.join(self.col.tags.canonify(self.tags)) - def setTagsFromStr(self, str): + def setTagsFromStr(self, str) -> None: self.tags = self.col.tags.split(str) - def delTag(self, tag): + def delTag(self, tag) -> None: rem = [] for t in self.tags: if t.lower() == tag.lower(): @@ -124,14 +125,14 @@ insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)""", for r in rem: self.tags.remove(r) - def addTag(self, tag): + def addTag(self, tag) -> None: # duplicates will be stripped on save self.tags.append(tag) # Unique/duplicate check ################################################## - def dupeOrEmpty(self): + def dupeOrEmpty(self) -> int: "1 if first is empty; 2 if first is a duplicate, False otherwise." val = self.fields[0] if not val.strip(): @@ -149,12 +150,12 @@ insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)""", # Flushing cloze notes ################################################## - def _preFlush(self): + def _preFlush(self) -> None: # have we been added yet? self.newlyAdded = not self.col.db.scalar( "select 1 from cards where nid = ?", self.id) - def _postFlush(self): + def _postFlush(self) -> None: # generate missing cards if not self.newlyAdded: rem = self.col.genCards([self.id]) diff --git a/anki/sched.py b/anki/sched.py index dcce7963f..ae2ad53ed 100644 --- a/anki/sched.py +++ b/anki/sched.py @@ -13,6 +13,9 @@ from anki.utils import ids2str, intTime, fmtTimeSpan from anki.lang import _ from anki.consts import * from anki.hooks import runHook +from typing import Any, List, Optional, Tuple, TypeVar + +_T = TypeVar('_T') # queue types: 0=new/cram, 1=lrn, 2=rev, 3=day lrn, -1=suspended, -2=buried # revlog types: 0=lrn, 1=rev, 2=relrn, 3=cram @@ -24,7 +27,7 @@ class Scheduler: _spreadRev = True _burySiblingsOnAnswer = True - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.queueLimit = 50 self.reportLimit = 1000 @@ -33,7 +36,7 @@ class Scheduler: self._haveQueues = False self._updateCutoff() - def getCard(self): + def getCard(self) -> Any: "Pop the next card from the queue. None if finished." self._checkDay() if not self._haveQueues: @@ -47,14 +50,14 @@ class Scheduler: card.startTimer() return card - def reset(self): + def reset(self) -> None: self._updateCutoff() self._resetLrn() self._resetRev() self._resetNew() self._haveQueues = True - def answerCard(self, card, ease): + def answerCard(self, card, ease) -> None: self.col.log() assert 1 <= ease <= 4 self.col.markReview(card) @@ -93,7 +96,7 @@ class Scheduler: card.usn = self.col.usn() card.flushSched() - def counts(self, card=None): + def counts(self, card=None) -> tuple: counts = [self.newCount, self.lrnCount, self.revCount] if card: idx = self.countIdx(card) @@ -103,7 +106,7 @@ class Scheduler: counts[idx] += 1 return tuple(counts) - def dueForecast(self, days=7): + def dueForecast(self, days=7) -> List: "Return counts over next DAYS. Includes today." daysd = dict(self.col.db.all(""" select due, count() from cards @@ -121,12 +124,12 @@ order by due""" % self._deckLimit(), ret = [x[1] for x in sorted(daysd.items())] return ret - def countIdx(self, card): + def countIdx(self, card) -> Any: if card.queue == 3: return 1 return card.queue - def answerButtons(self, card): + def answerButtons(self, card) -> int: if card.odue: # normal review in dyn deck? if card.odid and card.queue == 2: @@ -140,7 +143,7 @@ order by due""" % self._deckLimit(), else: return 3 - def unburyCards(self): + def unburyCards(self) -> None: "Unbury cards." self.col.conf['lastUnburied'] = self.today self.col.log( @@ -148,7 +151,7 @@ order by due""" % self._deckLimit(), self.col.db.execute( "update cards set queue=type where queue = -2") - def unburyCardsForDeck(self): + def unburyCardsForDeck(self) -> None: sids = ids2str(self.col.decks.active()) self.col.log( self.col.db.list("select id from cards where queue = -2 and did in %s" @@ -160,7 +163,7 @@ order by due""" % self._deckLimit(), # Rev/lrn/time daily stats ########################################################################## - def _updateStats(self, card, type, cnt=1): + def _updateStats(self, card, type, cnt=1) -> None: key = type+"Today" for g in ([self.col.decks.get(card.did)] + self.col.decks.parents(card.did)): @@ -168,7 +171,7 @@ order by due""" % self._deckLimit(), g[key][1] += cnt self.col.decks.save(g) - def extendLimits(self, new, rev): + def extendLimits(self, new, rev) -> None: cur = self.col.decks.current() parents = self.col.decks.parents(cur['id']) children = [self.col.decks.get(did) for (name, did) in @@ -179,7 +182,7 @@ order by due""" % self._deckLimit(), g['revToday'][1] -= rev self.col.decks.save(g) - def _walkingCount(self, limFn=None, cntFn=None): + def _walkingCount(self, limFn=None, cntFn=None) -> Any: tot = 0 pcounts = {} # for each of the active decks @@ -213,7 +216,7 @@ order by due""" % self._deckLimit(), # Deck list ########################################################################## - def deckDueList(self): + def deckDueList(self) -> List[list]: "Returns [deckname, did, rev, lrn, new]" self._checkDay() self.col.decks.checkIntegrity() @@ -247,10 +250,10 @@ order by due""" % self._deckLimit(), lims[deck['name']] = [nlim, rlim] return data - def deckDueTree(self): + def deckDueTree(self) -> Any: return self._groupChildren(self.deckDueList()) - def _groupChildren(self, grps): + def _groupChildren(self, grps) -> Tuple[Tuple[Any, Any, Any, Any, Any, Any], ...]: # first, split the group names into components for g in grps: g[0] = g[0].split("::") @@ -259,7 +262,7 @@ order by due""" % self._deckLimit(), # then run main function return self._groupChildrenMain(grps) - def _groupChildrenMain(self, grps): + def _groupChildrenMain(self, grps) -> Tuple[Tuple[Any, Any, Any, Any, Any, Any], ...]: tree = [] # group and recurse def key(grp): @@ -300,7 +303,7 @@ order by due""" % self._deckLimit(), # Getting the next card ########################################################################## - def _getCard(self): + def _getCard(self) -> Any: "Return the next due card id, or None." # learning card due? c = self._getLrnCard() @@ -329,19 +332,19 @@ order by due""" % self._deckLimit(), # New cards ########################################################################## - def _resetNewCount(self): + def _resetNewCount(self) -> None: cntFn = lambda did, lim: self.col.db.scalar(""" select count() from (select 1 from cards where did = ? and queue = 0 limit ?)""", did, lim) self.newCount = self._walkingCount(self._deckNewLimitSingle, cntFn) - def _resetNew(self): + def _resetNew(self) -> None: self._resetNewCount() self._newDids = self.col.decks.active()[:] self._newQueue = [] self._updateNewCardRatio() - def _fillNew(self): + def _fillNew(self) -> Any: if self._newQueue: return True if not self.newCount: @@ -365,12 +368,12 @@ did = ? and queue = 0 limit ?)""", did, lim) self._resetNew() return self._fillNew() - def _getNewCard(self): + def _getNewCard(self) -> Any: if self._fillNew(): self.newCount -= 1 return self.col.getCard(self._newQueue.pop()) - def _updateNewCardRatio(self): + def _updateNewCardRatio(self) -> None: if self.col.conf['newSpread'] == NEW_CARDS_DISTRIBUTE: if self.newCount: self.newCardModulus = ( @@ -381,7 +384,7 @@ did = ? and queue = 0 limit ?)""", did, lim) return self.newCardModulus = 0 - def _timeForNewCard(self): + def _timeForNewCard(self) -> Optional[int]: "True if it's time to display a new card when distributing." if not self.newCount: return False @@ -392,7 +395,7 @@ did = ? and queue = 0 limit ?)""", did, lim) elif self.newCardModulus: return self.reps and self.reps % self.newCardModulus == 0 - def _deckNewLimit(self, did, fn=None): + def _deckNewLimit(self, did, fn=None) -> Any: if not fn: fn = self._deckNewLimitSingle sel = self.col.decks.get(did) @@ -406,7 +409,7 @@ did = ? and queue = 0 limit ?)""", did, lim) lim = min(rem, lim) return lim - def _newForDeck(self, did, lim): + def _newForDeck(self, did, lim) -> Any: "New count for a single deck." if not lim: return 0 @@ -415,14 +418,14 @@ did = ? and queue = 0 limit ?)""", did, lim) select count() from (select 1 from cards where did = ? and queue = 0 limit ?)""", did, lim) - def _deckNewLimitSingle(self, g): + def _deckNewLimitSingle(self, g) -> Any: "Limit for deck without parent limits." if g['dyn']: return self.reportLimit c = self.col.decks.confForDid(g['id']) return max(0, c['new']['perDay'] - g['newToday'][1]) - def totalNewForCurrentDeck(self): + def totalNewForCurrentDeck(self) -> Any: return self.col.db.scalar( """ select count() from cards where id in ( @@ -432,7 +435,7 @@ select id from cards where did in %s and queue = 0 limit ?)""" # Learning queues ########################################################################## - def _resetLrnCount(self): + def _resetLrnCount(self) -> None: # sub-day self.lrnCount = self.col.db.scalar(""" select sum(left/1000) from (select left from cards where @@ -445,14 +448,14 @@ select count() from cards where did in %s and queue = 3 and due <= ? limit %d""" % (self._deckLimit(), self.reportLimit), self.today) - def _resetLrn(self): + def _resetLrn(self) -> None: self._resetLrnCount() self._lrnQueue = [] self._lrnDayQueue = [] self._lrnDids = self.col.decks.active()[:] # sub-day learning - def _fillLrn(self): + def _fillLrn(self) -> Any: if not self.lrnCount: return False if self._lrnQueue: @@ -465,7 +468,7 @@ limit %d""" % (self._deckLimit(), self.reportLimit), lim=self.dayCutoff) self._lrnQueue.sort() return self._lrnQueue - def _getLrnCard(self, collapse=False): + def _getLrnCard(self, collapse=False) -> Any: if self._fillLrn(): cutoff = time.time() if collapse: @@ -477,7 +480,7 @@ limit %d""" % (self._deckLimit(), self.reportLimit), lim=self.dayCutoff) return card # daily learning - def _fillLrnDay(self): + def _fillLrnDay(self) -> Optional[bool]: if not self.lrnCount: return False if self._lrnDayQueue: @@ -501,12 +504,12 @@ did = ? and queue = 3 and due <= ? limit ?""", # nothing left in the deck; move to next self._lrnDids.pop(0) - def _getLrnDayCard(self): + def _getLrnDayCard(self) -> Any: if self._fillLrnDay(): self.lrnCount -= 1 return self.col.getCard(self._lrnDayQueue.pop()) - def _answerLrnCard(self, card, ease): + def _answerLrnCard(self, card, ease) -> None: # ease 1=no, 2=yes, 3=remove conf = self._lrnConf(card) if card.odid and not card.wasNew: @@ -568,7 +571,7 @@ did = ? and queue = 3 and due <= ? limit ?""", card.queue = 3 self._logLrn(card, ease, conf, leaving, type, lastLeft) - def _delayForGrade(self, conf, left): + def _delayForGrade(self, conf, left) -> Any: left = left % 1000 try: delay = conf['delays'][-left] @@ -580,13 +583,13 @@ did = ? and queue = 3 and due <= ? limit ?""", delay = 1 return delay*60 - def _lrnConf(self, card): + def _lrnConf(self, card) -> Any: if card.type == 2: return self._lapseConf(card) else: return self._newConf(card) - def _rescheduleAsRev(self, card, conf, early): + def _rescheduleAsRev(self, card, conf, early) -> None: lapse = card.type == 2 if lapse: if self._resched(card): @@ -609,7 +612,7 @@ did = ? and queue = 3 and due <= ? limit ?""", card.queue = card.type = 0 card.due = self.col.nextID("pos") - def _startingLeft(self, card): + def _startingLeft(self, card) -> int: if card.type == 2: conf = self._lapseConf(card) else: @@ -618,7 +621,7 @@ did = ? and queue = 3 and due <= ? limit ?""", tod = self._leftToday(conf['delays'], tot) return tot + tod*1000 - def _leftToday(self, delays, left, now=None): + def _leftToday(self, delays, left, now=None) -> int: "The number of steps that can be completed by the day cutoff." if not now: now = intTime() @@ -631,7 +634,7 @@ did = ? and queue = 3 and due <= ? limit ?""", ok = i return ok+1 - def _graduatingIvl(self, card, conf, early, adj=True): + def _graduatingIvl(self, card, conf, early, adj=True) -> Any: if card.type == 2: # lapsed card being relearnt if card.odid: @@ -649,13 +652,13 @@ did = ? and queue = 3 and due <= ? limit ?""", else: return ideal - def _rescheduleNew(self, card, conf, early): + def _rescheduleNew(self, card, conf, early) -> None: "Reschedule a new card that's graduated for the first time." card.ivl = self._graduatingIvl(card, conf, early) card.due = self.today+card.ivl card.factor = conf['initialFactor'] - def _logLrn(self, card, ease, conf, leaving, type, lastLeft): + def _logLrn(self, card, ease, conf, leaving, type, lastLeft) -> None: lastIvl = -(self._delayForGrade(conf, lastLeft)) ivl = card.ivl if leaving else -(self._delayForGrade(conf, card.left)) def log(): @@ -670,7 +673,7 @@ did = ? and queue = 3 and due <= ? limit ?""", time.sleep(0.01) log() - def removeLrn(self, ids=None): + def removeLrn(self, ids=None) -> None: "Remove cards from the learning queues." if ids: extra = " and id in "+ids2str(ids) @@ -689,7 +692,7 @@ where queue in (1,3) and type = 2 self.forgetCards(self.col.db.list( "select id from cards where queue in (1,3) %s" % extra)) - def _lrnForDeck(self, did): + def _lrnForDeck(self, did) -> Any: cnt = self.col.db.scalar( """ select sum(left/1000) from @@ -705,16 +708,16 @@ and due <= ? limit ?)""", # Reviews ########################################################################## - def _deckRevLimit(self, did): + def _deckRevLimit(self, did) -> Any: return self._deckNewLimit(did, self._deckRevLimitSingle) - def _deckRevLimitSingle(self, d): + def _deckRevLimitSingle(self, d) -> Any: if d['dyn']: return self.reportLimit c = self.col.decks.confForDid(d['id']) return max(0, c['rev']['perDay'] - d['revToday'][1]) - def _revForDeck(self, did, lim): + def _revForDeck(self, did, lim) -> Any: lim = min(lim, self.reportLimit) return self.col.db.scalar( """ @@ -723,7 +726,7 @@ select count() from and due <= ? limit ?)""", did, self.today, lim) - def _resetRevCount(self): + def _resetRevCount(self) -> None: def cntFn(did, lim): return self.col.db.scalar(""" select count() from (select id from cards where @@ -732,12 +735,12 @@ did = ? and queue = 2 and due <= ? limit %d)""" % lim, self.revCount = self._walkingCount( self._deckRevLimitSingle, cntFn) - def _resetRev(self): + def _resetRev(self) -> None: self._resetRevCount() self._revQueue = [] self._revDids = self.col.decks.active()[:] - def _fillRev(self): + def _fillRev(self) -> Any: if self._revQueue: return True if not self.revCount: @@ -774,12 +777,12 @@ did = ? and queue = 2 and due <= ? limit ?""", self._resetRev() return self._fillRev() - def _getRevCard(self): + def _getRevCard(self) -> Any: if self._fillRev(): self.revCount -= 1 return self.col.getCard(self._revQueue.pop()) - def totalRevForCurrentDeck(self): + def totalRevForCurrentDeck(self) -> Any: return self.col.db.scalar( """ select count() from cards where id in ( @@ -789,7 +792,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Answering a review card ########################################################################## - def _answerRevCard(self, card, ease): + def _answerRevCard(self, card, ease) -> None: delay = 0 if ease == 1: delay = self._rescheduleLapse(card) @@ -797,7 +800,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" self._rescheduleRev(card, ease) self._logRev(card, ease, delay) - def _rescheduleLapse(self, card): + def _rescheduleLapse(self, card) -> Any: conf = self._lapseConf(card) card.lastIvl = card.ivl if self._resched(card): @@ -833,10 +836,10 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" card.queue = 3 return delay - def _nextLapseIvl(self, card, conf): + def _nextLapseIvl(self, card, conf) -> Any: return max(conf['minInt'], int(card.ivl*conf['mult'])) - def _rescheduleRev(self, card, ease): + def _rescheduleRev(self, card, ease) -> None: # update interval card.lastIvl = card.ivl if self._resched(card): @@ -851,7 +854,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" card.odid = 0 card.odue = 0 - def _logRev(self, card, ease, delay): + def _logRev(self, card, ease, delay) -> None: def log(): self.col.db.execute( "insert into revlog values (?,?,?,?,?,?,?,?,?)", @@ -868,7 +871,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Interval management ########################################################################## - def _nextRevIvl(self, card, ease): + def _nextRevIvl(self, card, ease) -> Any: "Ideal next interval for CARD, given EASE." delay = self._daysLate(card) conf = self._revConf(card) @@ -886,11 +889,11 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # interval capped? return min(interval, conf['maxIvl']) - def _fuzzedIvl(self, ivl): + def _fuzzedIvl(self, ivl) -> int: min, max = self._fuzzIvlRange(ivl) return random.randint(min, max) - def _fuzzIvlRange(self, ivl): + def _fuzzIvlRange(self, ivl) -> List: if ivl < 2: return [1, 1] elif ivl == 2: @@ -905,22 +908,22 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" fuzz = max(fuzz, 1) return [ivl-fuzz, ivl+fuzz] - def _constrainedIvl(self, ivl, conf, prev): + def _constrainedIvl(self, ivl, conf, prev) -> int: "Integer interval after interval factor and prev+1 constraints applied." new = ivl * conf.get('ivlFct', 1) return int(max(new, prev+1)) - def _daysLate(self, card): + def _daysLate(self, card) -> Any: "Number of days later than scheduled." due = card.odue if card.odid else card.due return max(0, self.today - due) - def _updateRevIvl(self, card, ease): + def _updateRevIvl(self, card, ease) -> None: idealIvl = self._nextRevIvl(card, ease) card.ivl = min(max(self._adjRevIvl(card, idealIvl), card.ivl+1), self._revConf(card)['maxIvl']) - def _adjRevIvl(self, card, idealIvl): + def _adjRevIvl(self, card, idealIvl) -> int: if self._spreadRev: idealIvl = self._fuzzedIvl(idealIvl) return idealIvl @@ -928,7 +931,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Dynamic deck handling ########################################################################## - def rebuildDyn(self, did=None): + def rebuildDyn(self, did=None) -> Any: "Rebuild a dynamic deck." did = did or self.col.decks.selected() deck = self.col.decks.get(did) @@ -942,7 +945,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" self.col.decks.select(did) return ids - def _fillDyn(self, deck): + def _fillDyn(self, deck) -> Any: search, limit, order = deck['terms'][0] orderlimit = self._dynOrder(order, limit) if search.strip(): @@ -958,7 +961,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" self._moveToDyn(deck['id'], ids) return ids - def emptyDyn(self, did, lim=None): + def emptyDyn(self, did, lim=None) -> None: if not lim: lim = "did = %s" % did self.col.log(self.col.db.list("select id from cards where %s" % lim)) @@ -969,10 +972,10 @@ else type end), type = (case when type = 1 then 0 else type end), due = odue, odue = 0, odid = 0, usn = ? where %s""" % lim, self.col.usn()) - def remFromDyn(self, cids): + def remFromDyn(self, cids) -> None: self.emptyDyn(None, "id in %s and odid" % ids2str(cids)) - def _dynOrder(self, o, l): + def _dynOrder(self, o, l) -> str: if o == DYN_OLDEST: t = "(select max(id) from revlog where cid=c.id)" elif o == DYN_RANDOM: @@ -997,7 +1000,7 @@ due = odue, odue = 0, odid = 0, usn = ? where %s""" % lim, t = "c.due" return t + " limit %d" % l - def _moveToDyn(self, did, ids): + def _moveToDyn(self, did, ids) -> None: deck = self.col.decks.get(did) data = [] t = intTime(); u = self.col.usn() @@ -1016,7 +1019,7 @@ odid = (case when odid then odid else did end), odue = (case when odue then odue else due end), did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) - def _dynIvlBoost(self, card): + def _dynIvlBoost(self, card) -> Any: assert card.odid and card.type == 2 assert card.factor elapsed = card.ivl - (card.odue - self.today) @@ -1028,7 +1031,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) # Leeches ########################################################################## - def _checkLeech(self, card, conf): + def _checkLeech(self, card, conf) -> Optional[bool]: "Leech handler. True if card was a leech." lf = conf['leechFails'] if not lf: @@ -1057,10 +1060,10 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) # Tools ########################################################################## - def _cardConf(self, card): + def _cardConf(self, card) -> Any: return self.col.decks.confForDid(card.did) - def _newConf(self, card): + def _newConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1080,7 +1083,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) perDay=self.reportLimit ) - def _lapseConf(self, card): + def _lapseConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1099,7 +1102,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) resched=conf['resched'], ) - def _revConf(self, card): + def _revConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1107,10 +1110,10 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) # dynamic deck return self.col.decks.confForDid(card.odid)['rev'] - def _deckLimit(self): + def _deckLimit(self) -> str: return ids2str(self.col.decks.active()) - def _resched(self, card): + def _resched(self, card) -> Any: conf = self._cardConf(card) if not conf['dyn']: return True @@ -1119,7 +1122,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) # Daily cutoff ########################################################################## - def _updateCutoff(self): + def _updateCutoff(self) -> None: oldToday = self.today # days since col created self.today = int((time.time() - self.col.crt) // 86400) @@ -1141,7 +1144,7 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) if unburied < self.today: self.unburyCards() - def _checkDay(self): + def _checkDay(self) -> None: # check if the day has rolled over if time.time() > self.dayCutoff: self.reset() @@ -1149,12 +1152,12 @@ did = ?, queue = %s, due = ?, usn = ? where id = ?""" % queue, data) # Deck finished state ########################################################################## - def finishedMsg(self): + def finishedMsg(self) -> str: return (""+_( "Congratulations! You have finished this deck for now.")+ "

" + self._nextDueMsg()) - def _nextDueMsg(self): + def _nextDueMsg(self) -> str: line = [] # the new line replacements are so we don't break translations # in a point release @@ -1181,20 +1184,20 @@ Some related or buried cards were delayed until a later session.""")+now) To study outside of the normal schedule, click the Custom Study button below.""")) return "

".join(line) - def revDue(self): + def revDue(self) -> Any: "True if there are any rev cards due." return self.col.db.scalar( ("select 1 from cards where did in %s and queue = 2 " "and due <= ? limit 1") % self._deckLimit(), self.today) - def newDue(self): + def newDue(self) -> Any: "True if there are any new cards due." return self.col.db.scalar( ("select 1 from cards where did in %s and queue = 0 " "limit 1") % self._deckLimit()) - def haveBuried(self): + def haveBuried(self) -> bool: sdids = ids2str(self.col.decks.active()) cnt = self.col.db.scalar( "select 1 from cards where queue = -2 and did in %s limit 1" % sdids) @@ -1203,7 +1206,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" # Next time reports ########################################################################## - def nextIvlStr(self, card, ease, short=False): + def nextIvlStr(self, card, ease, short=False) -> Any: "Return the next interval for CARD as a string." ivl = self.nextIvl(card, ease) if not ivl: @@ -1213,7 +1216,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" s = "<"+s return s - def nextIvl(self, card, ease): + def nextIvl(self, card, ease) -> Any: "Return the next interval for CARD, in seconds." if card.queue in (0,1,3): return self._nextLrnIvl(card, ease) @@ -1228,7 +1231,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" return self._nextRevIvl(card, ease)*86400 # this isn't easily extracted from the learn code - def _nextLrnIvl(self, card, ease): + def _nextLrnIvl(self, card, ease) -> Any: if card.queue == 0: card.left = self._startingLeft(card) conf = self._lrnConf(card) @@ -1253,7 +1256,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" # Suspending ########################################################################## - def suspendCards(self, ids): + def suspendCards(self, ids) -> None: "Suspend cards." self.col.log(ids) self.remFromDyn(ids) @@ -1262,7 +1265,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" "update cards set queue=-1,mod=?,usn=? where id in "+ ids2str(ids), intTime(), self.col.usn()) - def unsuspendCards(self, ids): + def unsuspendCards(self, ids) -> None: "Unsuspend cards." self.col.log(ids) self.col.db.execute( @@ -1270,7 +1273,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" "where queue = -1 and id in "+ ids2str(ids), intTime(), self.col.usn()) - def buryCards(self, cids): + def buryCards(self, cids) -> None: self.col.log(cids) self.remFromDyn(cids) self.removeLrn(cids) @@ -1278,7 +1281,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" update cards set queue=-2,mod=?,usn=? where id in """+ids2str(cids), intTime(), self.col.usn()) - def buryNote(self, nid): + def buryNote(self, nid) -> None: "Bury all cards for note until next session." cids = self.col.db.list( "select id from cards where nid = ? and queue >= 0", nid) @@ -1287,7 +1290,7 @@ update cards set queue=-2,mod=?,usn=? where id in """+ids2str(cids), # Sibling spacing ########################################################################## - def _burySiblings(self, card): + def _burySiblings(self, card) -> None: toBury = [] nconf = self._newConf(card) buryNew = nconf.get("bury", True) @@ -1324,7 +1327,7 @@ and (queue=0 or (queue=2 and due<=?))""", # Resetting ########################################################################## - def forgetCards(self, ids): + def forgetCards(self, ids) -> None: "Put cards at the end of the new queue." self.remFromDyn(ids) self.col.db.execute( @@ -1336,7 +1339,7 @@ and (queue=0 or (queue=2 and due<=?))""", self.sortCards(ids, start=pmax+1) self.col.log(ids) - def reschedCards(self, ids, imin, imax): + def reschedCards(self, ids, imin, imax) -> None: "Put cards in review queue with a new interval in days (min, max)." d = [] t = self.today @@ -1352,7 +1355,7 @@ usn=:usn,mod=:mod,factor=:fact where id=:id""", d) self.col.log(ids) - def resetCards(self, ids): + def resetCards(self, ids) -> None: "Completely reset cards for export." sids = ids2str(ids) # we want to avoid resetting due number of existing new cards on export @@ -1371,7 +1374,7 @@ usn=:usn,mod=:mod,factor=:fact where id=:id""", # Repositioning new cards ########################################################################## - def sortCards(self, cids, start=1, step=1, shuffle=False, shift=False): + def sortCards(self, cids, start=1, step=1, shuffle=False, shift=False) -> None: scids = ids2str(cids) now = intTime() nids = [] @@ -1411,15 +1414,15 @@ and due >= ? and queue = 0""" % scids, now, self.col.usn(), shiftby, low) self.col.db.executemany( "update cards set due=:due,mod=:now,usn=:usn where id = :cid", d) - def randomizeCards(self, did): + def randomizeCards(self, did) -> None: cids = self.col.db.list("select id from cards where did = ?", did) self.sortCards(cids, shuffle=True) - def orderCards(self, did): + def orderCards(self, did) -> None: cids = self.col.db.list("select id from cards where did = ? order by id", did) self.sortCards(cids) - def resortConf(self, conf): + def resortConf(self, conf) -> None: for did in self.col.decks.didsForConf(conf): if conf['new']['order'] == 0: self.randomizeCards(did) @@ -1427,7 +1430,7 @@ and due >= ? and queue = 0""" % scids, now, self.col.usn(), shiftby, low) self.orderCards(did) # for post-import - def maybeRandomizeDeck(self, did=None): + def maybeRandomizeDeck(self, did=None) -> None: if not did: did = self.col.decks.selected() conf = self.col.decks.confForDid(did) diff --git a/anki/schedv2.py b/anki/schedv2.py index 955e7bf03..34b21fec5 100644 --- a/anki/schedv2.py +++ b/anki/schedv2.py @@ -13,6 +13,7 @@ from anki.utils import ids2str, intTime, fmtTimeSpan from anki.lang import _ from anki.consts import * from anki.hooks import runHook +from typing import Any, List, Optional, Tuple # card types: 0=new, 1=lrn, 2=rev, 3=relrn # queue types: 0=new, 1=(re)lrn, 2=rev, 3=day (re)lrn, @@ -26,7 +27,7 @@ class Scheduler: haveCustomStudy = True _burySiblingsOnAnswer = True - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.queueLimit = 50 self.reportLimit = 1000 @@ -37,7 +38,7 @@ class Scheduler: self._lrnCutoff = 0 self._updateCutoff() - def getCard(self): + def getCard(self) -> Any: "Pop the next card from the queue. None if finished." self._checkDay() if not self._haveQueues: @@ -51,14 +52,14 @@ class Scheduler: card.startTimer() return card - def reset(self): + def reset(self) -> None: self._updateCutoff() self._resetLrn() self._resetRev() self._resetNew() self._haveQueues = True - def answerCard(self, card, ease): + def answerCard(self, card, ease) -> None: self.col.log() assert 1 <= ease <= 4 assert 0 <= card.queue <= 4 @@ -73,7 +74,7 @@ class Scheduler: card.usn = self.col.usn() card.flushSched() - def _answerCard(self, card, ease): + def _answerCard(self, card, ease) -> None: if self._previewingCard(card): self._answerCardPreview(card, ease) return @@ -103,7 +104,7 @@ class Scheduler: if card.odue: card.odue = 0 - def _answerCardPreview(self, card, ease): + def _answerCardPreview(self, card, ease) -> None: assert 1 <= ease <= 2 if ease == 1: @@ -116,14 +117,14 @@ class Scheduler: self._restorePreviewCard(card) self._removeFromFiltered(card) - def counts(self, card=None): + def counts(self, card=None) -> tuple: counts = [self.newCount, self.lrnCount, self.revCount] if card: idx = self.countIdx(card) counts[idx] += 1 return tuple(counts) - def dueForecast(self, days=7): + def dueForecast(self, days=7) -> List: "Return counts over next DAYS. Includes today." daysd = dict(self.col.db.all(""" select due, count() from cards @@ -141,12 +142,12 @@ order by due""" % self._deckLimit(), ret = [x[1] for x in sorted(daysd.items())] return ret - def countIdx(self, card): + def countIdx(self, card) -> Any: if card.queue in (3,4): return 1 return card.queue - def answerButtons(self, card): + def answerButtons(self, card) -> int: conf = self._cardConf(card) if card.odid and not conf['resched']: return 2 @@ -155,7 +156,7 @@ order by due""" % self._deckLimit(), # Rev/lrn/time daily stats ########################################################################## - def _updateStats(self, card, type, cnt=1): + def _updateStats(self, card, type, cnt=1) -> None: key = type+"Today" for g in ([self.col.decks.get(card.did)] + self.col.decks.parents(card.did)): @@ -163,7 +164,7 @@ order by due""" % self._deckLimit(), g[key][1] += cnt self.col.decks.save(g) - def extendLimits(self, new, rev): + def extendLimits(self, new, rev) -> None: cur = self.col.decks.current() parents = self.col.decks.parents(cur['id']) children = [self.col.decks.get(did) for (name, did) in @@ -174,7 +175,7 @@ order by due""" % self._deckLimit(), g['revToday'][1] -= rev self.col.decks.save(g) - def _walkingCount(self, limFn=None, cntFn=None): + def _walkingCount(self, limFn=None, cntFn=None) -> Any: tot = 0 pcounts = {} # for each of the active decks @@ -208,7 +209,7 @@ order by due""" % self._deckLimit(), # Deck list ########################################################################## - def deckDueList(self): + def deckDueList(self) -> List[list]: "Returns [deckname, did, rev, lrn, new]" self._checkDay() self.col.decks.checkIntegrity() @@ -245,10 +246,10 @@ order by due""" % self._deckLimit(), lims[deck['name']] = [nlim, rlim] return data - def deckDueTree(self): + def deckDueTree(self) -> Any: return self._groupChildren(self.deckDueList()) - def _groupChildren(self, grps): + def _groupChildren(self, grps) -> Tuple[Tuple[Any, Any, Any, Any, Any, Any], ...]: # first, split the group names into components for g in grps: g[0] = g[0].split("::") @@ -257,7 +258,7 @@ order by due""" % self._deckLimit(), # then run main function return self._groupChildrenMain(grps) - def _groupChildrenMain(self, grps): + def _groupChildrenMain(self, grps) -> Tuple[Tuple[Any, Any, Any, Any, Any, Any], ...]: tree = [] # group and recurse def key(grp): @@ -296,7 +297,7 @@ order by due""" % self._deckLimit(), # Getting the next card ########################################################################## - def _getCard(self): + def _getCard(self) -> Any: "Return the next due card id, or None." # learning card due? c = self._getLrnCard() @@ -338,19 +339,19 @@ order by due""" % self._deckLimit(), # New cards ########################################################################## - def _resetNewCount(self): + def _resetNewCount(self) -> None: cntFn = lambda did, lim: self.col.db.scalar(""" select count() from (select 1 from cards where did = ? and queue = 0 limit ?)""", did, lim) self.newCount = self._walkingCount(self._deckNewLimitSingle, cntFn) - def _resetNew(self): + def _resetNew(self) -> None: self._resetNewCount() self._newDids = self.col.decks.active()[:] self._newQueue = [] self._updateNewCardRatio() - def _fillNew(self): + def _fillNew(self) -> Any: if self._newQueue: return True if not self.newCount: @@ -374,12 +375,12 @@ did = ? and queue = 0 limit ?)""", did, lim) self._resetNew() return self._fillNew() - def _getNewCard(self): + def _getNewCard(self) -> Any: if self._fillNew(): self.newCount -= 1 return self.col.getCard(self._newQueue.pop()) - def _updateNewCardRatio(self): + def _updateNewCardRatio(self) -> None: if self.col.conf['newSpread'] == NEW_CARDS_DISTRIBUTE: if self.newCount: self.newCardModulus = ( @@ -390,7 +391,7 @@ did = ? and queue = 0 limit ?)""", did, lim) return self.newCardModulus = 0 - def _timeForNewCard(self): + def _timeForNewCard(self) -> Optional[int]: "True if it's time to display a new card when distributing." if not self.newCount: return False @@ -401,7 +402,7 @@ did = ? and queue = 0 limit ?)""", did, lim) elif self.newCardModulus: return self.reps and self.reps % self.newCardModulus == 0 - def _deckNewLimit(self, did, fn=None): + def _deckNewLimit(self, did, fn=None) -> Any: if not fn: fn = self._deckNewLimitSingle sel = self.col.decks.get(did) @@ -415,7 +416,7 @@ did = ? and queue = 0 limit ?)""", did, lim) lim = min(rem, lim) return lim - def _newForDeck(self, did, lim): + def _newForDeck(self, did, lim) -> Any: "New count for a single deck." if not lim: return 0 @@ -424,14 +425,14 @@ did = ? and queue = 0 limit ?)""", did, lim) select count() from (select 1 from cards where did = ? and queue = 0 limit ?)""", did, lim) - def _deckNewLimitSingle(self, g): + def _deckNewLimitSingle(self, g) -> Any: "Limit for deck without parent limits." if g['dyn']: return self.dynReportLimit c = self.col.decks.confForDid(g['id']) return max(0, c['new']['perDay'] - g['newToday'][1]) - def totalNewForCurrentDeck(self): + def totalNewForCurrentDeck(self) -> Any: return self.col.db.scalar( """ select count() from cards where id in ( @@ -442,18 +443,18 @@ select id from cards where did in %s and queue = 0 limit ?)""" ########################################################################## # scan for any newly due learning cards every minute - def _updateLrnCutoff(self, force): + def _updateLrnCutoff(self, force) -> bool: nextCutoff = intTime() + self.col.conf['collapseTime'] if nextCutoff - self._lrnCutoff > 60 or force: self._lrnCutoff = nextCutoff return True return False - def _maybeResetLrn(self, force): + def _maybeResetLrn(self, force) -> None: if self._updateLrnCutoff(force): self._resetLrn() - def _resetLrnCount(self): + def _resetLrnCount(self) -> None: # sub-day self.lrnCount = self.col.db.scalar(""" select count() from cards where did in %s and queue = 1 @@ -470,7 +471,7 @@ and due <= ?""" % (self._deckLimit()), select count() from cards where did in %s and queue = 4 """ % (self._deckLimit())) - def _resetLrn(self): + def _resetLrn(self) -> None: self._updateLrnCutoff(force=True) self._resetLrnCount() self._lrnQueue = [] @@ -478,7 +479,7 @@ select count() from cards where did in %s and queue = 4 self._lrnDids = self.col.decks.active()[:] # sub-day learning - def _fillLrn(self): + def _fillLrn(self) -> Any: if not self.lrnCount: return False if self._lrnQueue: @@ -492,7 +493,7 @@ limit %d""" % (self._deckLimit(), self.reportLimit), lim=cutoff) self._lrnQueue.sort() return self._lrnQueue - def _getLrnCard(self, collapse=False): + def _getLrnCard(self, collapse=False) -> Any: self._maybeResetLrn(force=collapse and self.lrnCount == 0) if self._fillLrn(): cutoff = time.time() @@ -505,7 +506,7 @@ limit %d""" % (self._deckLimit(), self.reportLimit), lim=cutoff) return card # daily learning - def _fillLrnDay(self): + def _fillLrnDay(self) -> Optional[bool]: if not self.lrnCount: return False if self._lrnDayQueue: @@ -529,12 +530,12 @@ did = ? and queue = 3 and due <= ? limit ?""", # nothing left in the deck; move to next self._lrnDids.pop(0) - def _getLrnDayCard(self): + def _getLrnDayCard(self) -> Any: if self._fillLrnDay(): self.lrnCount -= 1 return self.col.getCard(self._lrnDayQueue.pop()) - def _answerLrnCard(self, card, ease): + def _answerLrnCard(self, card, ease) -> None: conf = self._lrnConf(card) if card.type in (2,3): type = 2 @@ -565,11 +566,11 @@ did = ? and queue = 3 and due <= ? limit ?""", self._logLrn(card, ease, conf, leaving, type, lastLeft) - def _updateRevIvlOnFail(self, card, conf): + def _updateRevIvlOnFail(self, card, conf) -> None: card.lastIvl = card.ivl card.ivl = self._lapseIvl(card, conf) - def _moveToFirstStep(self, card, conf): + def _moveToFirstStep(self, card, conf) -> Any: card.left = self._startingLeft(card) # relearning card? @@ -578,18 +579,18 @@ did = ? and queue = 3 and due <= ? limit ?""", return self._rescheduleLrnCard(card, conf) - def _moveToNextStep(self, card, conf): + def _moveToNextStep(self, card, conf) -> None: # decrement real left count and recalculate left today left = (card.left % 1000) - 1 card.left = self._leftToday(conf['delays'], left)*1000 + left self._rescheduleLrnCard(card, conf) - def _repeatStep(self, card, conf): + def _repeatStep(self, card, conf) -> None: delay = self._delayForRepeatingGrade(conf, card.left) self._rescheduleLrnCard(card, conf, delay=delay) - def _rescheduleLrnCard(self, card, conf, delay=None): + def _rescheduleLrnCard(self, card, conf, delay=None) -> Any: # normal delay for the current step? if delay is None: delay = self._delayForGrade(conf, card.left) @@ -619,7 +620,7 @@ did = ? and queue = 3 and due <= ? limit ?""", card.queue = 3 return delay - def _delayForGrade(self, conf, left): + def _delayForGrade(self, conf, left) -> Any: left = left % 1000 try: delay = conf['delays'][-left] @@ -631,7 +632,7 @@ did = ? and queue = 3 and due <= ? limit ?""", delay = 1 return delay*60 - def _delayForRepeatingGrade(self, conf, left): + def _delayForRepeatingGrade(self, conf, left) -> Any: # halfway between last and next delay1 = self._delayForGrade(conf, left) if len(conf['delays']) > 1: @@ -641,13 +642,13 @@ did = ? and queue = 3 and due <= ? limit ?""", avg = (delay1+max(delay1, delay2))//2 return avg - def _lrnConf(self, card): + def _lrnConf(self, card) -> Any: if card.type in (2, 3): return self._lapseConf(card) else: return self._newConf(card) - def _rescheduleAsRev(self, card, conf, early): + def _rescheduleAsRev(self, card, conf, early) -> None: lapse = card.type in (2,3) if lapse: @@ -659,14 +660,14 @@ did = ? and queue = 3 and due <= ? limit ?""", if card.odid: self._removeFromFiltered(card) - def _rescheduleGraduatingLapse(self, card, early=False): + def _rescheduleGraduatingLapse(self, card, early=False) -> None: if early: card.ivl += 1 card.due = self.today+card.ivl card.queue = 2 card.type = 2 - def _startingLeft(self, card): + def _startingLeft(self, card) -> int: if card.type == 3: conf = self._lapseConf(card) else: @@ -675,7 +676,7 @@ did = ? and queue = 3 and due <= ? limit ?""", tod = self._leftToday(conf['delays'], tot) return tot + tod*1000 - def _leftToday(self, delays, left, now=None): + def _leftToday(self, delays, left, now=None) -> int: "The number of steps that can be completed by the day cutoff." if not now: now = intTime() @@ -688,7 +689,7 @@ did = ? and queue = 3 and due <= ? limit ?""", ok = i return ok+1 - def _graduatingIvl(self, card, conf, early, fuzz=True): + def _graduatingIvl(self, card, conf, early, fuzz=True) -> Any: if card.type in (2,3): bonus = early and 1 or 0 return card.ivl + bonus @@ -702,14 +703,14 @@ did = ? and queue = 3 and due <= ? limit ?""", ideal = self._fuzzedIvl(ideal) return ideal - def _rescheduleNew(self, card, conf, early): + def _rescheduleNew(self, card, conf, early) -> None: "Reschedule a new card that's graduated for the first time." card.ivl = self._graduatingIvl(card, conf, early) card.due = self.today+card.ivl card.factor = conf['initialFactor'] card.type = card.queue = 2 - def _logLrn(self, card, ease, conf, leaving, type, lastLeft): + def _logLrn(self, card, ease, conf, leaving, type, lastLeft) -> None: lastIvl = -(self._delayForGrade(conf, lastLeft)) ivl = card.ivl if leaving else -(self._delayForGrade(conf, card.left)) def log(): @@ -724,7 +725,7 @@ did = ? and queue = 3 and due <= ? limit ?""", time.sleep(0.01) log() - def _lrnForDeck(self, did): + def _lrnForDeck(self, did) -> Any: cnt = self.col.db.scalar( """ select count() from @@ -740,11 +741,11 @@ and due <= ? limit ?)""", # Reviews ########################################################################## - def _currentRevLimit(self): + def _currentRevLimit(self) -> Any: d = self.col.decks.get(self.col.decks.selected(), default=False) return self._deckRevLimitSingle(d) - def _deckRevLimitSingle(self, d, parentLimit=None): + def _deckRevLimitSingle(self, d, parentLimit=None) -> Any: # invalid deck selected? if not d: return 0 @@ -765,7 +766,7 @@ and due <= ? limit ?)""", lim = min(lim, self._deckRevLimitSingle(parent, parentLimit=lim)) return lim - def _revForDeck(self, did, lim, childMap): + def _revForDeck(self, did, lim, childMap) -> Any: dids = [did] + self.col.decks.childDids(did, childMap) lim = min(lim, self.reportLimit) return self.col.db.scalar( @@ -775,18 +776,18 @@ select count() from and due <= ? limit ?)""" % ids2str(dids), self.today, lim) - def _resetRevCount(self): + def _resetRevCount(self) -> None: lim = self._currentRevLimit() self.revCount = self.col.db.scalar(""" select count() from (select id from cards where did in %s and queue = 2 and due <= ? limit %d)""" % ( ids2str(self.col.decks.active()), lim), self.today) - def _resetRev(self): + def _resetRev(self) -> None: self._resetRevCount() self._revQueue = [] - def _fillRev(self): + def _fillRev(self) -> Any: if self._revQueue: return True if not self.revCount: @@ -813,12 +814,12 @@ limit ?""" % (ids2str(self.col.decks.active())), self._resetRev() return self._fillRev() - def _getRevCard(self): + def _getRevCard(self) -> Any: if self._fillRev(): self.revCount -= 1 return self.col.getCard(self._revQueue.pop()) - def totalRevForCurrentDeck(self): + def totalRevForCurrentDeck(self) -> Any: return self.col.db.scalar( """ select count() from cards where id in ( @@ -828,7 +829,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Answering a review card ########################################################################## - def _answerRevCard(self, card, ease): + def _answerRevCard(self, card, ease) -> None: delay = 0 early = card.odid and (card.odue > self.today) type = early and 3 or 1 @@ -840,7 +841,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" self._logRev(card, ease, delay, type) - def _rescheduleLapse(self, card): + def _rescheduleLapse(self, card) -> Any: conf = self._lapseConf(card) card.lapses += 1 @@ -862,11 +863,11 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" return delay - def _lapseIvl(self, card, conf): + def _lapseIvl(self, card, conf) -> Any: ivl = max(1, conf['minInt'], int(card.ivl*conf['mult'])) return ivl - def _rescheduleRev(self, card, ease, early): + def _rescheduleRev(self, card, ease, early) -> None: # update interval card.lastIvl = card.ivl if early: @@ -881,7 +882,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # card leaves filtered deck self._removeFromFiltered(card) - def _logRev(self, card, ease, delay, type): + def _logRev(self, card, ease, delay, type) -> None: def log(): self.col.db.execute( "insert into revlog values (?,?,?,?,?,?,?,?,?)", @@ -898,7 +899,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Interval management ########################################################################## - def _nextRevIvl(self, card, ease, fuzz): + def _nextRevIvl(self, card, ease, fuzz) -> int: "Next review interval for CARD, given EASE." delay = self._daysLate(card) conf = self._revConf(card) @@ -920,11 +921,11 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" (card.ivl + delay) * fct * conf['ease4'], conf, ivl3, fuzz) return ivl4 - def _fuzzedIvl(self, ivl): + def _fuzzedIvl(self, ivl) -> int: min, max = self._fuzzIvlRange(ivl) return random.randint(min, max) - def _fuzzIvlRange(self, ivl): + def _fuzzIvlRange(self, ivl) -> List: if ivl < 2: return [1, 1] elif ivl == 2: @@ -939,7 +940,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" fuzz = max(fuzz, 1) return [ivl-fuzz, ivl+fuzz] - def _constrainedIvl(self, ivl, conf, prev, fuzz): + def _constrainedIvl(self, ivl, conf, prev, fuzz) -> int: ivl = int(ivl * conf.get('ivlFct', 1)) if fuzz: ivl = self._fuzzedIvl(ivl) @@ -947,19 +948,19 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" ivl = min(ivl, conf['maxIvl']) return int(ivl) - def _daysLate(self, card): + def _daysLate(self, card) -> Any: "Number of days later than scheduled." due = card.odue if card.odid else card.due return max(0, self.today - due) - def _updateRevIvl(self, card, ease): + def _updateRevIvl(self, card, ease) -> None: card.ivl = self._nextRevIvl(card, ease, fuzz=True) - def _updateEarlyRevIvl(self, card, ease): + def _updateEarlyRevIvl(self, card, ease) -> None: card.ivl = self._earlyReviewIvl(card, ease) # next interval for card when answered early+correctly - def _earlyReviewIvl(self, card, ease): + def _earlyReviewIvl(self, card, ease) -> int: assert card.odid and card.type == 2 assert card.factor assert ease > 1 @@ -997,7 +998,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" # Dynamic deck handling ########################################################################## - def rebuildDyn(self, did=None): + def rebuildDyn(self, did=None) -> Optional[int]: "Rebuild a dynamic deck." did = did or self.col.decks.selected() deck = self.col.decks.get(did) @@ -1011,7 +1012,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" self.col.decks.select(did) return cnt - def _fillDyn(self, deck): + def _fillDyn(self, deck) -> int: start = -100000 total = 0 for search, limit, order in deck['terms']: @@ -1029,7 +1030,7 @@ select id from cards where did in %s and queue = 2 and due <= ? limit ?)""" total += len(ids) return total - def emptyDyn(self, did, lim=None): + def emptyDyn(self, did, lim=None) -> None: if not lim: lim = "did = %s" % did self.col.log(self.col.db.list("select id from cards where %s" % lim)) @@ -1040,10 +1041,10 @@ due = (case when odue>0 then odue else due end), odue = 0, odid = 0, usn = ? whe self._restoreQueueSnippet, lim), self.col.usn()) - def remFromDyn(self, cids): + def remFromDyn(self, cids) -> None: self.emptyDyn(None, "id in %s and odid" % ids2str(cids)) - def _dynOrder(self, o, l): + def _dynOrder(self, o, l) -> str: if o == DYN_OLDEST: t = "(select max(id) from revlog where cid=c.id)" elif o == DYN_RANDOM: @@ -1065,7 +1066,7 @@ due = (case when odue>0 then odue else due end), odue = 0, odid = 0, usn = ? whe t = "c.due, c.ord" return t + " limit %d" % l - def _moveToDyn(self, did, ids, start=-100000): + def _moveToDyn(self, did, ids, start=-100000) -> None: deck = self.col.decks.get(did) data = [] u = self.col.usn() @@ -1089,13 +1090,13 @@ where id = ? """ % queue self.col.db.executemany(query, data) - def _removeFromFiltered(self, card): + def _removeFromFiltered(self, card) -> None: if card.odid: card.did = card.odid card.odue = 0 card.odid = 0 - def _restorePreviewCard(self, card): + def _restorePreviewCard(self, card) -> None: assert card.odid card.due = card.odue @@ -1113,7 +1114,7 @@ where id = ? # Leeches ########################################################################## - def _checkLeech(self, card, conf): + def _checkLeech(self, card, conf) -> Optional[bool]: "Leech handler. True if card was a leech." lf = conf['leechFails'] if not lf: @@ -1136,10 +1137,10 @@ where id = ? # Tools ########################################################################## - def _cardConf(self, card): + def _cardConf(self, card) -> Any: return self.col.decks.confForDid(card.did) - def _newConf(self, card): + def _newConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1158,7 +1159,7 @@ where id = ? perDay=self.reportLimit ) - def _lapseConf(self, card): + def _lapseConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1176,7 +1177,7 @@ where id = ? resched=conf['resched'], ) - def _revConf(self, card): + def _revConf(self, card) -> Any: conf = self._cardConf(card) # normal deck if not card.odid: @@ -1184,20 +1185,20 @@ where id = ? # dynamic deck return self.col.decks.confForDid(card.odid)['rev'] - def _deckLimit(self): + def _deckLimit(self) -> str: return ids2str(self.col.decks.active()) - def _previewingCard(self, card): + def _previewingCard(self, card) -> Any: conf = self._cardConf(card) return conf['dyn'] and not conf['resched'] - def _previewDelay(self, card): + def _previewDelay(self, card) -> Any: return self._cardConf(card).get("previewDelay", 10)*60 # Daily cutoff ########################################################################## - def _updateCutoff(self): + def _updateCutoff(self) -> None: oldToday = self.today # days since col created self.today = self._daysSinceCreation() @@ -1220,12 +1221,12 @@ where id = ? self.unburyCards() self.col.conf['lastUnburied'] = self.today - def _checkDay(self): + def _checkDay(self) -> None: # check if the day has rolled over if time.time() > self.dayCutoff: self.reset() - def _dayCutoff(self): + def _dayCutoff(self) -> int: rolloverTime = self.col.conf.get("rollover", 4) if rolloverTime < 0: rolloverTime = 24+rolloverTime @@ -1237,7 +1238,7 @@ where id = ? stamp = int(time.mktime(date.timetuple())) return stamp - def _daysSinceCreation(self): + def _daysSinceCreation(self) -> int: startDate = datetime.datetime.fromtimestamp(self.col.crt) startDate = startDate.replace(hour=self.col.conf.get("rollover", 4), minute=0, second=0, microsecond=0) @@ -1246,12 +1247,12 @@ where id = ? # Deck finished state ########################################################################## - def finishedMsg(self): + def finishedMsg(self) -> str: return (""+_( "Congratulations! You have finished this deck for now.")+ "

" + self._nextDueMsg()) - def _nextDueMsg(self): + def _nextDueMsg(self) -> str: line = [] # the new line replacements are so we don't break translations # in a point release @@ -1278,38 +1279,38 @@ Some related or buried cards were delayed until a later session.""")+now) To study outside of the normal schedule, click the Custom Study button below.""")) return "

".join(line) - def revDue(self): + def revDue(self) -> Any: "True if there are any rev cards due." return self.col.db.scalar( ("select 1 from cards where did in %s and queue = 2 " "and due <= ? limit 1") % self._deckLimit(), self.today) - def newDue(self): + def newDue(self) -> Any: "True if there are any new cards due." return self.col.db.scalar( ("select 1 from cards where did in %s and queue = 0 " "limit 1") % self._deckLimit()) - def haveBuriedSiblings(self): + def haveBuriedSiblings(self) -> bool: sdids = ids2str(self.col.decks.active()) cnt = self.col.db.scalar( "select 1 from cards where queue = -2 and did in %s limit 1" % sdids) return not not cnt - def haveManuallyBuried(self): + def haveManuallyBuried(self) -> bool: sdids = ids2str(self.col.decks.active()) cnt = self.col.db.scalar( "select 1 from cards where queue = -3 and did in %s limit 1" % sdids) return not not cnt - def haveBuried(self): + def haveBuried(self) -> bool: return self.haveManuallyBuried() or self.haveBuriedSiblings() # Next time reports ########################################################################## - def nextIvlStr(self, card, ease, short=False): + def nextIvlStr(self, card, ease, short=False) -> Any: "Return the next interval for CARD as a string." ivl = self.nextIvl(card, ease) if not ivl: @@ -1319,7 +1320,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" s = "<"+s return s - def nextIvl(self, card, ease): + def nextIvl(self, card, ease) -> Any: "Return the next interval for CARD, in seconds." # preview mode? if self._previewingCard(card): @@ -1345,7 +1346,7 @@ To study outside of the normal schedule, click the Custom Study button below.""" return self._nextRevIvl(card, ease, fuzz=False)*86400 # this isn't easily extracted from the learn code - def _nextLrnIvl(self, card, ease): + def _nextLrnIvl(self, card, ease) -> Any: if card.queue == 0: card.left = self._startingLeft(card) conf = self._lrnConf(card) @@ -1377,14 +1378,14 @@ else end) """ - def suspendCards(self, ids): + def suspendCards(self, ids) -> None: "Suspend cards." self.col.log(ids) self.col.db.execute( "update cards set queue=-1,mod=?,usn=? where id in "+ ids2str(ids), intTime(), self.col.usn()) - def unsuspendCards(self, ids): + def unsuspendCards(self, ids) -> None: "Unsuspend cards." self.col.log(ids) self.col.db.execute( @@ -1392,27 +1393,27 @@ end) "where queue = -1 and id in %s") % (self._restoreQueueSnippet, ids2str(ids)), intTime(), self.col.usn()) - def buryCards(self, cids, manual=True): + def buryCards(self, cids, manual=True) -> None: queue = manual and -3 or -2 self.col.log(cids) self.col.db.execute(""" update cards set queue=?,mod=?,usn=? where id in """+ids2str(cids), queue, intTime(), self.col.usn()) - def buryNote(self, nid): + def buryNote(self, nid) -> None: "Bury all cards for note until next session." cids = self.col.db.list( "select id from cards where nid = ? and queue >= 0", nid) self.buryCards(cids) - def unburyCards(self): + def unburyCards(self) -> None: "Unbury all buried cards in all decks." self.col.log( self.col.db.list("select id from cards where queue in (-2, -3)")) self.col.db.execute( "update cards set %s where queue in (-2, -3)" % self._restoreQueueSnippet) - def unburyCardsForDeck(self, type="all"): + def unburyCardsForDeck(self, type="all") -> None: if type == "all": queue = "queue in (-2, -3)" elif type == "manual": @@ -1433,7 +1434,7 @@ update cards set queue=?,mod=?,usn=? where id in """+ids2str(cids), # Sibling spacing ########################################################################## - def _burySiblings(self, card): + def _burySiblings(self, card) -> None: toBury = [] nconf = self._newConf(card) buryNew = nconf.get("bury", True) @@ -1467,7 +1468,7 @@ and (queue=0 or (queue=2 and due<=?))""", # Resetting ########################################################################## - def forgetCards(self, ids): + def forgetCards(self, ids) -> None: "Put cards at the end of the new queue." self.remFromDyn(ids) self.col.db.execute( @@ -1479,7 +1480,7 @@ and (queue=0 or (queue=2 and due<=?))""", self.sortCards(ids, start=pmax+1) self.col.log(ids) - def reschedCards(self, ids, imin, imax): + def reschedCards(self, ids, imin, imax) -> None: "Put cards in review queue with a new interval in days (min, max)." d = [] t = self.today @@ -1495,7 +1496,7 @@ usn=:usn,mod=:mod,factor=:fact where id=:id""", d) self.col.log(ids) - def resetCards(self, ids): + def resetCards(self, ids) -> None: "Completely reset cards for export." sids = ids2str(ids) # we want to avoid resetting due number of existing new cards on export @@ -1514,7 +1515,7 @@ usn=:usn,mod=:mod,factor=:fact where id=:id""", # Repositioning new cards ########################################################################## - def sortCards(self, cids, start=1, step=1, shuffle=False, shift=False): + def sortCards(self, cids, start=1, step=1, shuffle=False, shift=False) -> None: scids = ids2str(cids) now = intTime() nids = [] @@ -1554,15 +1555,15 @@ and due >= ? and queue = 0""" % scids, now, self.col.usn(), shiftby, low) self.col.db.executemany( "update cards set due=:due,mod=:now,usn=:usn where id = :cid", d) - def randomizeCards(self, did): + def randomizeCards(self, did) -> None: cids = self.col.db.list("select id from cards where did = ?", did) self.sortCards(cids, shuffle=True) - def orderCards(self, did): + def orderCards(self, did) -> None: cids = self.col.db.list("select id from cards where did = ? order by id", did) self.sortCards(cids) - def resortConf(self, conf): + def resortConf(self, conf) -> None: for did in self.col.decks.didsForConf(conf): if conf['new']['order'] == 0: self.randomizeCards(did) @@ -1570,7 +1571,7 @@ and due >= ? and queue = 0""" % scids, now, self.col.usn(), shiftby, low) self.orderCards(did) # for post-import - def maybeRandomizeDeck(self, did=None): + def maybeRandomizeDeck(self, did=None) -> None: if not did: did = self.col.decks.selected() conf = self.col.decks.confForDid(did) @@ -1581,7 +1582,7 @@ and due >= ? and queue = 0""" % scids, now, self.col.usn(), shiftby, low) # Changing scheduler versions ########################################################################## - def _emptyAllFiltered(self): + def _emptyAllFiltered(self) -> None: self.col.db.execute(""" update cards set did = odid, queue = (case when type = 1 then 0 @@ -1593,7 +1594,7 @@ else type end), due = odue, odue = 0, odid = 0, usn = ? where odid != 0""", self.col.usn()) - def _removeAllFromLearning(self, schedVer=2): + def _removeAllFromLearning(self, schedVer=2) -> None: # remove review cards from relearning if schedVer == 1: self.col.db.execute(""" @@ -1612,7 +1613,7 @@ due = odue, odue = 0, odid = 0, usn = ? where odid != 0""", "select id from cards where queue in (1,3)")) # v1 doesn't support buried/suspended (re)learning cards - def _resetSuspendedLearning(self): + def _resetSuspendedLearning(self) -> None: self.col.db.execute(""" update cards set type = (case when type = 1 then 0 @@ -1624,15 +1625,15 @@ mod = %d, usn = %d where queue < 0""" % (intTime(), self.col.usn())) # no 'manually buried' queue in v1 - def _moveManuallyBuried(self): + def _moveManuallyBuried(self) -> None: self.col.db.execute("update cards set queue=-2,mod=%d where queue=-3" % intTime()) # adding 'hard' in v2 scheduler means old ease entries need shifting # up or down - def _remapLearningAnswers(self, sql): + def _remapLearningAnswers(self, sql) -> None: self.col.db.execute("update revlog set %s and type in (0,2)" % sql) - def moveToV1(self): + def moveToV1(self) -> None: self._emptyAllFiltered() self._removeAllFromLearning() @@ -1640,7 +1641,7 @@ where queue < 0""" % (intTime(), self.col.usn())) self._resetSuspendedLearning() self._remapLearningAnswers("ease=ease-1 where ease in (3,4)") - def moveToV2(self): + def moveToV2(self) -> None: self._emptyAllFiltered() self._removeAllFromLearning(schedVer=1) self._remapLearningAnswers("ease=ease+1 where ease in (2,3)") diff --git a/anki/sound.py b/anki/sound.py index 72760b485..f98406ea7 100644 --- a/anki/sound.py +++ b/anki/sound.py @@ -5,7 +5,9 @@ import html import re, sys, threading, time, subprocess, os, atexit import random -from typing import List +from typing import List, Tuple, Dict, Any +from typing import Callable, NoReturn, Optional + from anki.hooks import addHook, runHook from anki.utils import tmpdir, isWin, isMac, isLin from anki.lang import _ @@ -15,19 +17,19 @@ from anki.lang import _ _soundReg = r"\[sound:(.*?)\]" -def playFromText(text): +def playFromText(text) -> None: for match in allSounds(text): # filename is html encoded match = html.unescape(match) play(match) -def allSounds(text): +def allSounds(text) -> List: return re.findall(_soundReg, text) -def stripSounds(text): +def stripSounds(text) -> str: return re.sub(_soundReg, "", text) -def hasSound(text): +def hasSound(text) -> bool: return re.search(_soundReg, text) is not None # Packaged commands @@ -35,7 +37,7 @@ def hasSound(text): # return modified command array that points to bundled command, and return # required environment -def _packagedCmd(cmd): +def _packagedCmd(cmd) -> Tuple[Any, Dict[str, str]]: cmd = cmd[:] env = os.environ.copy() if "LD_LIBRARY_PATH" in env: @@ -76,7 +78,7 @@ if sys.platform == "win32": else: si = None -def retryWait(proc): +def retryWait(proc) -> Any: # osx throws interrupted system call errors frequently while 1: try: @@ -89,6 +91,10 @@ def retryWait(proc): from anki.mpv import MPV, MPVBase +_player: Optional[Callable[[Any], Any]] +_queueEraser: Optional[Callable[[], Any]] +_soundReg: str + mpvPath, mpvEnv = _packagedCmd(["mpv"]) class MpvManager(MPV): @@ -101,28 +107,28 @@ class MpvManager(MPV): "--input-media-keys=no", ] - def __init__(self): + def __init__(self) -> None: super().__init__(window_id=None, debug=False) - def queueFile(self, file): + def queueFile(self, file) -> None: runHook("mpvWillPlay", file) path = os.path.join(os.getcwd(), file) self.command("loadfile", path, "append-play") - def clearQueue(self): + def clearQueue(self) -> None: self.command("stop") - def togglePause(self): + def togglePause(self) -> None: self.set_property("pause", not self.get_property("pause")) - def seekRelative(self, secs): + def seekRelative(self, secs) -> None: self.command("seek", secs, "relative") - def on_idle(self): + def on_idle(self) -> None: runHook("mpvIdleHook") -def setMpvConfigBase(base): +def setMpvConfigBase(base) -> None: mpvConfPath = os.path.join(base, "mpv.conf") MpvManager.default_argv += [ "--no-config", @@ -131,14 +137,14 @@ def setMpvConfigBase(base): mpvManager = None -def setupMPV(): +def setupMPV() -> None: global mpvManager, _player, _queueEraser mpvManager = MpvManager() _player = mpvManager.queueFile _queueEraser = mpvManager.clearQueue atexit.register(cleanupMPV) -def cleanupMPV(): +def cleanupMPV() -> None: global mpvManager, _player, _queueEraser if mpvManager: mpvManager.close() @@ -151,7 +157,7 @@ def cleanupMPV(): # if anki crashes, an old mplayer instance may be left lying around, # which prevents renaming or deleting the profile -def cleanupOldMplayerProcesses(): +def cleanupOldMplayerProcesses() -> None: # pylint: disable=import-error import psutil # pytype: disable=import-error @@ -189,7 +195,7 @@ class MplayerMonitor(threading.Thread): mplayer = None deadPlayers: List[subprocess.Popen] = [] - def run(self): + def run(self) -> NoReturn: global mplayerClear self.mplayer = None self.deadPlayers = [] @@ -244,7 +250,7 @@ class MplayerMonitor(threading.Thread): return True self.deadPlayers = [pl for pl in self.deadPlayers if clean(pl)] - def kill(self): + def kill(self) -> None: if not self.mplayer: return try: @@ -255,7 +261,7 @@ class MplayerMonitor(threading.Thread): pass self.mplayer = None - def startProcess(self): + def startProcess(self) -> subprocess.Popen: try: cmd = mplayerCmd + ["-slave", "-idle"] cmd, env = _packagedCmd(cmd) @@ -267,7 +273,7 @@ class MplayerMonitor(threading.Thread): mplayerEvt.clear() raise Exception("Did you install mplayer?") -def queueMplayer(path): +def queueMplayer(path) -> None: ensureMplayerThreads() if isWin and os.path.exists(path): # mplayer on windows doesn't like the encoding, so we create a @@ -284,13 +290,13 @@ def queueMplayer(path): mplayerQueue.append(path) mplayerEvt.set() -def clearMplayerQueue(): +def clearMplayerQueue() -> None: global mplayerClear, mplayerQueue mplayerQueue = [] mplayerClear = True mplayerEvt.set() -def ensureMplayerThreads(): +def ensureMplayerThreads() -> None: global mplayerManager if not mplayerManager: mplayerManager = MplayerMonitor() @@ -302,7 +308,7 @@ def ensureMplayerThreads(): # clean up mplayer on exit atexit.register(stopMplayer) -def stopMplayer(*args): +def stopMplayer(*args) -> None: if not mplayerManager: return mplayerManager.kill() @@ -326,7 +332,7 @@ except: class _Recorder: - def postprocess(self, encode=True): + def postprocess(self, encode=True) -> None: self.encode = encode for c in processingChain: #print c @@ -344,18 +350,18 @@ class _Recorder: "Error running %s") % " ".join(cmd)) - def cleanup(self): + def cleanup(self) -> None: if os.path.exists(processingSrc): os.unlink(processingSrc) class PyAudioThreadedRecorder(threading.Thread): - def __init__(self, startupDelay): + def __init__(self, startupDelay) -> None: threading.Thread.__init__(self) self.startupDelay = startupDelay self.finish = False - def run(self): + def run(self) -> Any: chunk = 1024 p = pyaudio.PyAudio() @@ -421,10 +427,10 @@ if not pyaudio: _player = queueMplayer _queueEraser = clearMplayerQueue -def play(path): +def play(path) -> None: _player(path) -def clearAudioQueue(): +def clearAudioQueue() -> None: _queueEraser() Recorder = PyAudioRecorder diff --git a/anki/stats.py b/anki/stats.py index 69df4b935..d9c98789c 100644 --- a/anki/stats.py +++ b/anki/stats.py @@ -8,6 +8,7 @@ import json from anki.utils import fmtTimeSpan, ids2str from anki.lang import _, ngettext +from typing import Any, List, Tuple, Optional # Card stats @@ -15,12 +16,12 @@ from anki.lang import _, ngettext class CardStats: - def __init__(self, col, card): + def __init__(self, col, card) -> None: self.col = col self.card = card self.txt = "" - def report(self): + def report(self) -> str: c = self.card # pylint: disable=unnecessary-lambda fmt = lambda x, **kwargs: fmtTimeSpan(x, short=True, **kwargs) @@ -65,24 +66,24 @@ class CardStats: self.txt += "" return self.txt - def addLine(self, k, v): + def addLine(self, k, v) -> None: self.txt += self.makeLine(k, v) - def makeLine(self, k, v): + def makeLine(self, k, v) -> str: txt = "" txt += "%s%s" % (k, v) return txt - def date(self, tm): + def date(self, tm) -> str: return time.strftime("%Y-%m-%d", time.localtime(tm)) - def time(self, tm): - str = "" + def time(self, tm) -> str: + s = "" if tm >= 60: - str = fmtTimeSpan((tm/60)*60, short=True, point=-1, unit=1) - if tm%60 != 0 or not str: - str += fmtTimeSpan(tm%60, point=2 if not str else -1, short=True) - return str + s = fmtTimeSpan((tm/60)*60, short=True, point=-1, unit=1) + if tm%60 != 0 or not s: + s += fmtTimeSpan(tm%60, point=2 if not s else -1, short=True) + return s # Collection stats ########################################################################## @@ -101,7 +102,7 @@ colSusp = "#ff0" class CollectionStats: - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self._stats = None self.type = 0 @@ -110,7 +111,7 @@ class CollectionStats: self.wholeCollection = False # assumes jquery & plot are available in document - def report(self, type=0): + def report(self, type=0) -> str: # 0=days, 1=weeks, 2=months self.type = type from .statsbg import bg @@ -126,7 +127,7 @@ class CollectionStats: txt += self._section(self.footer()) return "

%s
" % txt - def _section(self, txt): + def _section(self, txt) -> str: return "
%s
" % txt css = """ @@ -143,7 +144,7 @@ body {background-image: url(data:image/png;base64,%s); } # Today stats ###################################################################### - def todayStats(self): + def todayStats(self) -> str: b = self._title(_("Today")) # studied today lim = self._revlogLimit() @@ -199,13 +200,13 @@ from revlog where id > ? """+lim, (self.col.sched.dayCutoff-86400)*1000) # Due and cumulative due ###################################################################### - def get_start_end_chunk(self, by='review'): + def get_start_end_chunk(self, by='review') -> Tuple[int, Optional[int], int]: start = 0 if self.type == 0: end, chunk = 31, 1 elif self.type == 1: end, chunk = 52, 7 - elif self.type == 2: + else: # self.type == 2: end = None if self._deckAge(by) <= 100: chunk = 1 @@ -215,7 +216,7 @@ from revlog where id > ? """+lim, (self.col.sched.dayCutoff-86400)*1000) chunk = 31 return start, end, chunk - def dueGraph(self): + def dueGraph(self) -> str: start, end, chunk = self.get_start_end_chunk() d = self._due(start, end, chunk) yng = [] @@ -251,7 +252,7 @@ from revlog where id > ? """+lim, (self.col.sched.dayCutoff-86400)*1000) txt += self._dueInfo(tot, len(totd)*chunk) return txt - def _dueInfo(self, tot, num): + def _dueInfo(self, tot, num) -> str: i = [] self._line(i, _("Total"), ngettext("%d review", "%d reviews", tot) % tot) self._line(i, _("Average"), self._avgDay( @@ -263,7 +264,7 @@ and due = ?""" % self._limit(), self.col.sched.today+1) self._line(i, _("Due tomorrow"), tomorrow) return self._lineTbl(i) - def _due(self, start=None, end=None, chunk=1): + def _due(self, start=None, end=None, chunk=1) -> Any: lim = "" if start is not None: lim += " and due-:today >= %d" % start @@ -283,7 +284,7 @@ group by day order by day""" % (self._limit(), lim), # Added, reps and time spent ###################################################################### - def introductionGraph(self): + def introductionGraph(self) -> str: start, days, chunk = self.get_start_end_chunk() data = self._added(days, chunk) if not data: @@ -315,7 +316,7 @@ group by day order by day""" % (self._limit(), lim), return txt - def repsGraphs(self): + def repsGraphs(self) -> str: start, days, chunk = self.get_start_end_chunk() data = self._done(days, chunk) if not data: @@ -363,7 +364,7 @@ group by day order by day""" % (self._limit(), lim), txt2 += rep return self._section(txt1) + self._section(txt2) - def _ansInfo(self, totd, studied, first, unit, convHours=False, total=None): + def _ansInfo(self, totd, studied, first, unit, convHours=False, total=None) -> Tuple[str, int]: assert(totd) tot = totd[-1][1] period = self._periodDays() @@ -404,7 +405,7 @@ group by day order by day""" % (self._limit(), lim), _("%(a)0.1fs (%(b)s)") % dict(a=(tot*60)/total, b=text)) return self._lineTbl(i), int(tot) - def _splitRepData(self, data, spec): + def _splitRepData(self, data, spec) -> Tuple[List[dict], List[Tuple[Any, Any]]]: sep = {} totcnt = {} totd = {} @@ -433,7 +434,7 @@ group by day order by day""" % (self._limit(), lim), bars={'show': False}, lines=dict(show=True), stack=-n)) return (ret, alltot) - def _added(self, num=7, chunk=1): + def _added(self, num=7, chunk=1) -> Any: lims = [] if num is not None: lims.append("id > %d" % ( @@ -454,7 +455,7 @@ count(id) from cards %s group by day order by day""" % lim, cut=self.col.sched.dayCutoff,tf=tf, chunk=chunk) - def _done(self, num=7, chunk=1): + def _done(self, num=7, chunk=1) -> Any: lims = [] if num is not None: lims.append("id > %d" % ( @@ -490,7 +491,7 @@ group by day order by day""" % lim, tf=tf, chunk=chunk) - def _daysStudied(self): + def _daysStudied(self) -> Any: lims = [] num = self._periodDays() if num: @@ -516,7 +517,7 @@ group by day order by day)""" % lim, # Intervals ###################################################################### - def ivlGraph(self): + def ivlGraph(self) -> str: (ivls, all, avg, max_), chunk = self._ivls() tot = 0 totd = [] @@ -545,7 +546,7 @@ group by day order by day)""" % lim, self._line(i, _("Longest interval"), fmtTimeSpan(max_*86400)) return txt + self._lineTbl(i) - def _ivls(self): + def _ivls(self) -> Tuple[list, int]: start, end, chunk = self.get_start_end_chunk() lim = "and grp <= %d" % end if end else "" data = [self.col.db.all(""" @@ -560,7 +561,7 @@ select count(), avg(ivl), max(ivl) from cards where did in %s and queue = 2""" % # Eases ###################################################################### - def easeGraph(self): + def easeGraph(self) -> str: # 3 + 4 + 4 + spaces on sides and middle = 15 # yng starts at 1+3+1 = 5 # mtr starts at 5+4+1 = 10 @@ -591,7 +592,7 @@ select count(), avg(ivl), max(ivl) from cards where did in %s and queue = 2""" % txt += self._easeInfo(eases) return txt - def _easeInfo(self, eases): + def _easeInfo(self, eases) -> str: types = {0: [0, 0], 1: [0, 0], 2: [0,0]} for (type, ease, cnt) in eases: if ease == 1: @@ -614,7 +615,7 @@ select count(), avg(ivl), max(ivl) from cards where did in %s and queue = 2""" % "".join(i) + "") - def _eases(self): + def _eases(self) -> Any: lims = [] lim = self._revlogLimit() if lim: @@ -643,7 +644,7 @@ order by thetype, ease""" % (ease4repl, lim)) # Hourly retention ###################################################################### - def hourGraph(self): + def hourGraph(self) -> str: data = self._hourRet() if not data: return "" @@ -690,7 +691,7 @@ order by thetype, ease""" % (ease4repl, lim)) txt += _("Hours with less than 30 reviews are not shown.") return txt - def _hourRet(self): + def _hourRet(self) -> Any: lim = self._revlogLimit() if lim: lim = " and " + lim @@ -715,7 +716,7 @@ group by hour having count() > 30 order by hour""" % lim, # Cards ###################################################################### - def cardGraph(self): + def cardGraph(self) -> str: # graph data div = self._cards() d = [] @@ -749,7 +750,7 @@ when you answer "good" on a review.''') info) return txt - def _line(self, i, a, b, bold=True): + def _line(self, i, a, b, bold=True) -> None: #T: Symbols separating first and second column in a statistics table. Eg in "Total: 3 reviews". colon = _(":") if bold: @@ -757,10 +758,10 @@ when you answer "good" on a review.''') else: i.append(("%s%s%s") % (a,colon,b)) - def _lineTbl(self, i): + def _lineTbl(self, i) -> str: return "" + "".join(i) + "
" - def _factors(self): + def _factors(self) -> Any: return self.col.db.first(""" select min(factor) / 10.0, @@ -768,7 +769,7 @@ avg(factor) / 10.0, max(factor) / 10.0 from cards where did in %s and queue = 2""" % self._limit()) - def _cards(self): + def _cards(self) -> Any: return self.col.db.first(""" select sum(case when queue=2 and ivl >= 21 then 1 else 0 end), -- mtr @@ -780,7 +781,7 @@ from cards where did in %s""" % self._limit()) # Footer ###################################################################### - def footer(self): + def footer(self) -> str: b = "

" b += _("Generated on %s") % time.asctime(time.localtime(time.time())) b += "
" @@ -801,7 +802,7 @@ from cards where did in %s""" % self._limit()) ###################################################################### def _graph(self, id, data, conf=None, - type="bars", xunit=1, ylabel=_("Cards"), ylabel2=""): + type="bars", xunit=1, ylabel=_("Cards"), ylabel2="") -> str: if conf is None: conf = {} # display settings @@ -902,21 +903,21 @@ $(function () { ylab=ylabel, ylab2=ylabel2, data=json.dumps(data), conf=json.dumps(conf))) - def _limit(self): + def _limit(self) -> Any: if self.wholeCollection: return ids2str([d['id'] for d in self.col.decks.all()]) return self.col.sched._deckLimit() - def _revlogLimit(self): + def _revlogLimit(self) -> str: if self.wholeCollection: return "" return ("cid in (select id from cards where did in %s)" % ids2str(self.col.decks.active())) - def _title(self, title, subtitle=""): + def _title(self, title, subtitle="") -> str: return '

%s

%s' % (title, subtitle) - def _deckAge(self, by): + def _deckAge(self, by) -> int: lim = self._revlogLimit() if lim: lim = " where " + lim @@ -932,13 +933,13 @@ $(function () { 1, int(1+((self.col.sched.dayCutoff - (t/1000)) / 86400))) return period - def _periodDays(self): + def _periodDays(self) -> Optional[int]: start, end, chunk = self.get_start_end_chunk() if end is None: return None return end * chunk - def _avgDay(self, tot, num, unit): + def _avgDay(self, tot, num, unit) -> str: vals = [] try: vals.append(_("%(a)0.1f %(b)s/day") % dict(a=tot/float(num), b=unit)) diff --git a/anki/stdmodels.py b/anki/stdmodels.py index ba86fc72d..c96413a73 100644 --- a/anki/stdmodels.py +++ b/anki/stdmodels.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright: Ankitects Pty Ltd and contributors # License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +from typing import Dict, Any from anki.lang import _ from anki.consts import MODEL_CLOZE @@ -10,7 +11,7 @@ models = [] # Basic ########################################################################## -def _newBasicModel(col, name=None): +def _newBasicModel(col, name=None) -> Dict[str, Any]: mm = col.models m = mm.new(name or _("Basic")) fm = mm.newField(_("Front")) @@ -23,7 +24,7 @@ def _newBasicModel(col, name=None): mm.addTemplate(m, t) return m -def addBasicModel(col): +def addBasicModel(col) -> Dict[str, Any]: m = _newBasicModel(col) col.models.add(m) return m @@ -33,7 +34,7 @@ models.append((lambda: _("Basic"), addBasicModel)) # Basic w/ typing ########################################################################## -def addBasicTypingModel(col): +def addBasicTypingModel(col) -> Dict[str, Any]: mm = col.models m = _newBasicModel(col, _("Basic (type in the answer)")) t = m['tmpls'][0] @@ -47,7 +48,7 @@ models.append((lambda: _("Basic (type in the answer)"), addBasicTypingModel)) # Forward & Reverse ########################################################################## -def _newForwardReverse(col, name=None): +def _newForwardReverse(col, name=None) -> Dict[str, Any]: mm = col.models m = _newBasicModel(col, name or _("Basic (and reversed card)")) t = mm.newTemplate(_("Card 2")) @@ -56,7 +57,7 @@ def _newForwardReverse(col, name=None): mm.addTemplate(m, t) return m -def addForwardReverse(col): +def addForwardReverse(col) -> Dict[str, Any]: m = _newForwardReverse(col) col.models.add(m) return m @@ -66,7 +67,7 @@ models.append((lambda: _("Basic (and reversed card)"), addForwardReverse)) # Forward & Optional Reverse ########################################################################## -def addForwardOptionalReverse(col): +def addForwardOptionalReverse(col) -> Dict[str, Any]: mm = col.models m = _newForwardReverse(col, _("Basic (optional reversed card)")) av = _("Add Reverse") @@ -83,7 +84,7 @@ models.append((lambda: _("Basic (optional reversed card)"), # Cloze ########################################################################## -def addClozeModel(col): +def addClozeModel(col) -> Dict[str, Any]: mm = col.models m = mm.new(_("Cloze")) m['type'] = MODEL_CLOZE diff --git a/anki/storage.py b/anki/storage.py index 6375d6f1c..815ee7e72 100644 --- a/anki/storage.py +++ b/anki/storage.py @@ -14,8 +14,11 @@ from anki.collection import _Collection from anki.consts import * from anki.stdmodels import addBasicModel, addClozeModel, addForwardReverse, \ addForwardOptionalReverse, addBasicTypingModel +from typing import Any, Dict, List, Optional, Tuple, Type, Union -def Collection(path, lock=True, server=False, log=False): +_Collection: Type[_Collection] + +def Collection(path, lock=True, server=False, log=False) -> _Collection: "Open a new or existing collection. Path must be unicode." assert path.endswith(".anki2") path = os.path.abspath(path) @@ -54,7 +57,7 @@ def Collection(path, lock=True, server=False, log=False): col.lock() return col -def _upgradeSchema(db): +def _upgradeSchema(db) -> Any: ver = db.scalar("select ver from col") if ver == SCHEMA_VERSION: return ver @@ -83,7 +86,7 @@ id, guid, mid, mod, usn, tags, flds, sfld, csum, flags, data from notes2""") _updateIndices(db) return ver -def _upgrade(col, ver): +def _upgrade(col, ver) -> None: if ver < 3: # new deck properties for d in col.decks.all(): @@ -184,7 +187,7 @@ update cards set left = left + left*1000 where queue = 1""") col.models.save(m) col.db.execute("update col set ver = 11") -def _upgradeClozeModel(col, m): +def _upgradeClozeModel(col, m) -> None: m['type'] = MODEL_CLOZE # convert first template t = m['tmpls'][0] @@ -205,7 +208,7 @@ def _upgradeClozeModel(col, m): # Creating a new collection ###################################################################### -def _createDB(db): +def _createDB(db) -> int: db.execute("pragma page_size = 4096") db.execute("pragma legacy_file_format = 0") db.execute("vacuum") @@ -214,7 +217,7 @@ def _createDB(db): db.execute("analyze") return SCHEMA_VERSION -def _addSchema(db, setColConf=True): +def _addSchema(db, setColConf=True) -> None: db.executescript(""" create table if not exists col ( id integer primary key, @@ -291,7 +294,7 @@ values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}'); if setColConf: _addColVars(db, *_getColVars(db)) -def _getColVars(db): +def _getColVars(db) -> Tuple[Any, Any, Dict[str, Optional[Union[int, str, List[int]]]]]: import anki.collection import anki.decks g = copy.deepcopy(anki.decks.defaultDeck) @@ -303,14 +306,14 @@ def _getColVars(db): gc['id'] = 1 return g, gc, anki.collection.defaultConf.copy() -def _addColVars(db, g, gc, c): +def _addColVars(db, g, gc, c) -> None: db.execute(""" update col set conf = ?, decks = ?, dconf = ?""", json.dumps(c), json.dumps({'1': g}), json.dumps({'1': gc})) -def _updateIndices(db): +def _updateIndices(db) -> None: "Add indices to the DB." db.executescript(""" -- syncing diff --git a/anki/sync.py b/anki/sync.py index 48de878ef..ec75f4eca 100644 --- a/anki/sync.py +++ b/anki/sync.py @@ -16,6 +16,7 @@ from anki.utils import versionWithBuild from .hooks import runHook import anki from .lang import ngettext +from typing import Any, Dict, List, Optional, Tuple, Union # syncing vars HTTP_TIMEOUT = 90 @@ -30,7 +31,7 @@ class UnexpectedSchemaChange(Exception): class Syncer: - def __init__(self, col, server=None): + def __init__(self, col, server=None) -> None: self.col = col self.server = server @@ -39,7 +40,7 @@ class Syncer: self.maxUsn = 0 self.tablesLeft = [] - def sync(self): + def sync(self) -> str: "Returns 'noChanges', 'fullSync', 'success', etc" self.syncMsg = "" self.uname = "" @@ -138,14 +139,14 @@ class Syncer: self.finish(mod) return "success" - def _forceFullSync(self): + def _forceFullSync(self) -> str: # roll back and force full sync self.col.rollback() self.col.modSchema(False) self.col.save() return "sanityCheckFailed" - def _gravesChunk(self, graves): + def _gravesChunk(self, graves: Dict) -> Tuple[Dict, Optional[Dict]]: lim = 250 chunk = dict(notes=[], cards=[], decks=[]) for cat in "notes", "cards", "decks": @@ -159,7 +160,7 @@ class Syncer: return chunk, graves return chunk, None - def meta(self): + def meta(self) -> dict: return dict( mod=self.col.mod, scm=self.col.scm, @@ -170,7 +171,7 @@ class Syncer: cont=True ) - def changes(self): + def changes(self) -> dict: "Bundle up small objects." d = dict(models=self.getModels(), decks=self.getDecks(), @@ -180,7 +181,7 @@ class Syncer: d['crt'] = self.col.crt return d - def mergeChanges(self, lchg, rchg): + def mergeChanges(self, lchg, rchg) -> None: # then the other objects self.mergeModels(rchg['models']) self.mergeDecks(rchg['decks']) @@ -192,7 +193,7 @@ class Syncer: self.col.crt = rchg['crt'] self.prepareToChunk() - def sanityCheck(self): + def sanityCheck(self) -> Union[list, str]: if not self.col.basicCheck(): return "failed basic check" for t in "cards", "notes", "revlog", "graves": @@ -226,10 +227,10 @@ class Syncer: len(self.col.decks.allConf()), ] - def usnLim(self): + def usnLim(self) -> str: return "usn = -1" - def finish(self, mod=None): + def finish(self, mod: int) -> int: self.col.ls = mod self.col._usn = self.maxUsn + 1 # ensure we save the mod time even if no changes made @@ -240,11 +241,11 @@ class Syncer: # Chunked syncing ########################################################################## - def prepareToChunk(self): + def prepareToChunk(self) -> None: self.tablesLeft = ["revlog", "cards", "notes"] self.cursor = None - def cursorForTable(self, table): + def cursorForTable(self, table) -> Any: lim = self.usnLim() x = self.col.db.execute d = (self.maxUsn, lim) @@ -261,7 +262,7 @@ lapses, left, odue, odid, flags, data from cards where %s""" % d) select id, guid, mid, mod, %d, tags, flds, '', '', flags, data from notes where %s""" % d) - def chunk(self): + def chunk(self) -> dict: buf = dict(done=False) lim = 250 while self.tablesLeft and lim: @@ -284,7 +285,7 @@ from notes where %s""" % d) buf['done'] = True return buf - def applyChunk(self, chunk): + def applyChunk(self, chunk) -> None: if "revlog" in chunk: self.mergeRevlog(chunk['revlog']) if "cards" in chunk: @@ -295,7 +296,7 @@ from notes where %s""" % d) # Deletions ########################################################################## - def removed(self): + def removed(self) -> dict: cards = [] notes = [] decks = [] @@ -316,7 +317,7 @@ from notes where %s""" % d) return dict(cards=cards, notes=notes, decks=decks) - def remove(self, graves): + def remove(self, graves) -> None: # pretend to be the server so we don't set usn = -1 self.col.server = True @@ -333,14 +334,14 @@ from notes where %s""" % d) # Models ########################################################################## - def getModels(self): + def getModels(self) -> List: mods = [m for m in self.col.models.all() if m['usn'] == -1] for m in mods: m['usn'] = self.maxUsn self.col.models.save() return mods - def mergeModels(self, rchg): + def mergeModels(self, rchg) -> None: for r in rchg: l = self.col.models.get(r['id']) # if missing locally or server is newer, update @@ -358,7 +359,7 @@ from notes where %s""" % d) # Decks ########################################################################## - def getDecks(self): + def getDecks(self) -> List[list]: decks = [g for g in self.col.decks.all() if g['usn'] == -1] for g in decks: g['usn'] = self.maxUsn @@ -368,7 +369,7 @@ from notes where %s""" % d) self.col.decks.save() return [decks, dconf] - def mergeDecks(self, rchg): + def mergeDecks(self, rchg) -> None: for r in rchg[0]: l = self.col.decks.get(r['id'], False) # work around mod time being stored as string @@ -390,7 +391,7 @@ from notes where %s""" % d) # Tags ########################################################################## - def getTags(self): + def getTags(self) -> List: tags = [] for t, usn in self.col.tags.allItems(): if usn == -1: @@ -399,18 +400,18 @@ from notes where %s""" % d) self.col.tags.save() return tags - def mergeTags(self, tags): + def mergeTags(self, tags) -> None: self.col.tags.register(tags, usn=self.maxUsn) # Cards/notes/revlog ########################################################################## - def mergeRevlog(self, logs): + def mergeRevlog(self, logs) -> None: self.col.db.executemany( "insert or ignore into revlog values (?,?,?,?,?,?,?,?,?)", logs) - def newerRows(self, data, table, modIdx): + def newerRows(self, data, table, modIdx) -> List: ids = (r[0] for r in data) lmods = {} for id, mod in self.col.db.execute( @@ -424,13 +425,13 @@ from notes where %s""" % d) self.col.log(table, data) return update - def mergeCards(self, cards): + def mergeCards(self, cards) -> None: self.col.db.executemany( "insert or replace into cards values " "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", self.newerRows(cards, "cards", 4)) - def mergeNotes(self, notes): + def mergeNotes(self, notes) -> None: rows = self.newerRows(notes, "notes", 3) self.col.db.executemany( "insert or replace into notes values (?,?,?,?,?,?,?,?,?,?,?)", @@ -440,10 +441,10 @@ from notes where %s""" % d) # Col config ########################################################################## - def getConf(self): + def getConf(self) -> Any: return self.col.conf - def mergeConf(self, conf): + def mergeConf(self, conf) -> None: self.col.conf = conf # Wrapper for requests that tracks upload/download progress @@ -454,22 +455,22 @@ class AnkiRequestsClient: verify = True timeout = 60 - def __init__(self): + def __init__(self) -> None: self.session = requests.Session() - def post(self, url, data, headers): + def post(self, url, data, headers) -> Any: data = _MonitoringFile(data) # pytype: disable=wrong-arg-types headers['User-Agent'] = self._agentName() return self.session.post( url, data=data, headers=headers, stream=True, timeout=self.timeout, verify=self.verify) # pytype: disable=wrong-arg-types - def get(self, url, headers=None): + def get(self, url, headers=None) -> requests.models.Response: if headers is None: headers = {} headers['User-Agent'] = self._agentName() return self.session.get(url, stream=True, headers=headers, timeout=self.timeout, verify=self.verify) - def streamContent(self, resp): + def streamContent(self, resp) -> bytes: resp.raise_for_status() buf = io.BytesIO() @@ -478,7 +479,7 @@ class AnkiRequestsClient: buf.write(chunk) return buf.getvalue() - def _agentName(self): + def _agentName(self) -> str: from anki import version return "Anki {}".format(version) @@ -490,7 +491,7 @@ if os.environ.get("ANKI_NOVERIFYSSL"): warnings.filterwarnings("ignore") class _MonitoringFile(io.BufferedReader): - def read(self, size=-1): + def read(self, size=-1) -> bytes: data = io.BufferedReader.read(self, HTTP_BUF_SIZE) runHook("httpSend", len(data)) return data @@ -500,7 +501,7 @@ class _MonitoringFile(io.BufferedReader): class HttpSyncer: - def __init__(self, hkey=None, client=None, hostNum=None): + def __init__(self, hkey=None, client=None, hostNum=None) -> None: self.hkey = hkey self.skey = checksum(str(random.random()))[:8] self.client = client or AnkiRequestsClient() @@ -508,14 +509,14 @@ class HttpSyncer: self.hostNum = hostNum self.prefix = "sync/" - def syncURL(self): + def syncURL(self) -> str: if devMode: url = "https://l1sync.ankiweb.net/" else: url = SYNC_BASE % (self.hostNum or "") return url + self.prefix - def assertOk(self, resp): + def assertOk(self, resp) -> None: # not using raise_for_status() as aqt expects this error msg if resp.status_code != 200: raise Exception("Unknown response code: %s" % resp.status_code) @@ -526,7 +527,7 @@ class HttpSyncer: # costly. We could send it as a raw post, but more HTTP clients seem to # support file uploading, so this is the more compatible choice. - def _buildPostData(self, fobj, comp): + def _buildPostData(self, fobj, comp) -> Tuple[Dict[str, str], io.BytesIO]: BOUNDARY=b"Anki-sync-boundary" bdry = b"--"+BOUNDARY buf = io.BytesIO() @@ -573,7 +574,7 @@ Content-Type: application/octet-stream\r\n\r\n""") return headers, buf - def req(self, method, fobj=None, comp=6, badAuthRaises=True): + def req(self, method, fobj=None, comp=6, badAuthRaises=True) -> Any: headers, body = self._buildPostData(fobj, comp) r = self.client.post(self.syncURL()+method, data=body, headers=headers) @@ -589,10 +590,10 @@ Content-Type: application/octet-stream\r\n\r\n""") class RemoteServer(HttpSyncer): - def __init__(self, hkey, hostNum): + def __init__(self, hkey, hostNum) -> None: HttpSyncer.__init__(self, hkey, hostNum=hostNum) - def hostKey(self, user, pw): + def hostKey(self, user, pw) -> Any: "Returns hkey or none if user/pw incorrect." self.postVars = dict() ret = self.req( @@ -604,7 +605,7 @@ class RemoteServer(HttpSyncer): self.hkey = json.loads(ret.decode("utf8"))['key'] return self.hkey - def meta(self): + def meta(self) -> Any: self.postVars = dict( k=self.hkey, s=self.skey, @@ -618,31 +619,31 @@ class RemoteServer(HttpSyncer): return return json.loads(ret.decode("utf8")) - def applyGraves(self, **kw): + def applyGraves(self, **kw) -> Any: return self._run("applyGraves", kw) - def applyChanges(self, **kw): + def applyChanges(self, **kw) -> Any: return self._run("applyChanges", kw) - def start(self, **kw): + def start(self, **kw) -> Any: return self._run("start", kw) - def chunk(self, **kw): + def chunk(self, **kw) -> Any: return self._run("chunk", kw) - def applyChunk(self, **kw): + def applyChunk(self, **kw) -> Any: return self._run("applyChunk", kw) - def sanityCheck2(self, **kw): + def sanityCheck2(self, **kw) -> Any: return self._run("sanityCheck2", kw) - def finish(self, **kw): + def finish(self, **kw) -> Any: return self._run("finish", kw) - def abort(self, **kw): + def abort(self, **kw) -> Any: return self._run("abort", kw) - def _run(self, cmd, data): + def _run(self, cmd, data) -> Any: return json.loads( self.req(cmd, io.BytesIO(json.dumps(data).encode("utf8"))).decode("utf8")) @@ -651,7 +652,7 @@ class RemoteServer(HttpSyncer): class FullSyncer(HttpSyncer): - def __init__(self, col, hkey, client, hostNum): + def __init__(self, col, hkey, client, hostNum) -> None: HttpSyncer.__init__(self, hkey, client, hostNum=hostNum) self.postVars = dict( k=self.hkey, @@ -659,7 +660,7 @@ class FullSyncer(HttpSyncer): ) self.col = col - def download(self): + def download(self) -> Optional[str]: runHook("sync", "download") localNotEmpty = self.col.db.scalar("select 1 from cards") self.col.close() @@ -683,7 +684,7 @@ class FullSyncer(HttpSyncer): os.rename(tpath, self.col.path) self.col = None - def upload(self): + def upload(self) -> bool: "True if upload successful." runHook("sync", "upload") # make sure it's ok before we try to upload @@ -709,12 +710,12 @@ class FullSyncer(HttpSyncer): class MediaSyncer: - def __init__(self, col, server=None): + def __init__(self, col, server=None) -> None: self.col = col self.server = server self.downloadCount = 0 - def sync(self): + def sync(self) -> Any: # check if there have been any changes runHook("sync", "findMedia") self.col.log("findChanges") @@ -824,7 +825,7 @@ class MediaSyncer: self.col.media.forceResync() return ret - def _downloadFiles(self, fnames): + def _downloadFiles(self, fnames) -> None: self.col.log("%d files to fetch"%len(fnames)) while fnames: top = fnames[0:SYNC_ZIP_COUNT] @@ -845,12 +846,12 @@ class MediaSyncer: class RemoteMediaServer(HttpSyncer): - def __init__(self, col, hkey, client, hostNum): + def __init__(self, col, hkey, client, hostNum) -> None: self.col = col HttpSyncer.__init__(self, hkey, client, hostNum=hostNum) self.prefix = "msync/" - def begin(self): + def begin(self) -> Any: self.postVars = dict( k=self.hkey, v="ankidesktop,%s,%s"%(anki.version, platDesc()) @@ -861,7 +862,7 @@ class RemoteMediaServer(HttpSyncer): return ret # args: lastUsn - def mediaChanges(self, **kw): + def mediaChanges(self, **kw) -> Any: self.postVars = dict( sk=self.skey, ) @@ -869,20 +870,20 @@ class RemoteMediaServer(HttpSyncer): self.req("mediaChanges", io.BytesIO(json.dumps(kw).encode("utf8")))) # args: files - def downloadFiles(self, **kw): + def downloadFiles(self, **kw) -> Any: return self.req("downloadFiles", io.BytesIO(json.dumps(kw).encode("utf8"))) - def uploadChanges(self, zip): + def uploadChanges(self, zip) -> Any: # no compression, as we compress the zip file instead return self._dataOnly( self.req("uploadChanges", io.BytesIO(zip), comp=0)) # args: local - def mediaSanity(self, **kw): + def mediaSanity(self, **kw) -> Any: return self._dataOnly( self.req("mediaSanity", io.BytesIO(json.dumps(kw).encode("utf8")))) - def _dataOnly(self, resp): + def _dataOnly(self, resp) -> Any: resp = json.loads(resp.decode("utf8")) if resp['err']: self.col.log("error returned:%s"%resp['err']) @@ -890,7 +891,7 @@ class RemoteMediaServer(HttpSyncer): return resp['data'] # only for unit tests - def mediatest(self, cmd): + def mediatest(self, cmd) -> Any: self.postVars = dict( k=self.hkey, ) diff --git a/anki/tags.py b/anki/tags.py index 60ddaa9ad..171b0a20e 100644 --- a/anki/tags.py +++ b/anki/tags.py @@ -14,21 +14,22 @@ import json from anki.utils import intTime, ids2str from anki.hooks import runHook import re +from typing import Any, List, Tuple class TagManager: # Registry save/load ############################################################# - def __init__(self, col): + def __init__(self, col) -> None: self.col = col self.tags = {} - def load(self, json_): + def load(self, json_) -> None: self.tags = json.loads(json_) self.changed = False - def flush(self): + def flush(self) -> None: if self.changed: self.col.db.execute("update col set tags=?", json.dumps(self.tags)) @@ -37,7 +38,7 @@ class TagManager: # Registering and fetching tags ############################################################# - def register(self, tags, usn=None): + def register(self, tags, usn=None) -> None: "Given a list of tags, add any missing ones to tag registry." found = False for t in tags: @@ -48,10 +49,10 @@ class TagManager: if found: runHook("newTag") - def all(self): + def all(self) -> List: return list(self.tags.keys()) - def registerNotes(self, nids=None): + def registerNotes(self, nids=None) -> None: "Add any missing tags from notes to the tags list." # when called without an argument, the old list is cleared first. if nids: @@ -63,13 +64,13 @@ class TagManager: self.register(set(self.split( " ".join(self.col.db.list("select distinct tags from notes"+lim))))) - def allItems(self): + def allItems(self) -> List[Tuple[Any, Any]]: return list(self.tags.items()) - def save(self): + def save(self) -> None: self.changed = True - def byDeck(self, did, children=False): + def byDeck(self, did, children=False) -> List: basequery = "select n.tags from cards c, notes n WHERE c.nid = n.id" if not children: query = basequery + " AND c.did=?" @@ -85,7 +86,7 @@ class TagManager: # Bulk addition/removal from notes ############################################################# - def bulkAdd(self, ids, tags, add=True): + def bulkAdd(self, ids, tags, add=True) -> None: "Add tags in bulk. TAGS is space-separated." newTags = self.split(tags) if not newTags: @@ -117,23 +118,23 @@ class TagManager: "update notes set tags=:t,mod=:n,usn=:u where id = :id", [fix(row) for row in res]) - def bulkRem(self, ids, tags): + def bulkRem(self, ids, tags) -> None: self.bulkAdd(ids, tags, False) # String-based utilities ########################################################################## - def split(self, tags): + def split(self, tags) -> List: "Parse a string and return a list of tags." return [t for t in tags.replace('\u3000', ' ').split(" ") if t] - def join(self, tags): + def join(self, tags) -> str: "Join tags into a single string, with leading and trailing spaces." if not tags: return "" return " %s " % " ".join(tags) - def addToStr(self, addtags, tags): + def addToStr(self, addtags, tags) -> str: "Add tags if they don't exist, and canonify." currentTags = self.split(tags) for tag in self.split(addtags): @@ -141,7 +142,7 @@ class TagManager: currentTags.append(tag) return self.join(self.canonify(currentTags)) - def remFromStr(self, deltags, tags): + def remFromStr(self, deltags, tags) -> str: "Delete tags if they exist." def wildcard(pat, str): pat = re.escape(pat).replace('\\*', '.*') @@ -161,7 +162,7 @@ class TagManager: # List-based utilities ########################################################################## - def canonify(self, tagList): + def canonify(self, tagList) -> List: "Strip duplicates, adjust case to match existing tags, and sort." strippedTags = [] for t in tagList: @@ -172,14 +173,14 @@ class TagManager: strippedTags.append(s) return sorted(set(strippedTags)) - def inList(self, tag, tags): + def inList(self, tag, tags) -> bool: "True if TAG is in TAGS. Ignore case." return tag.lower() in [t.lower() for t in tags] # Sync handling ########################################################################## - def beforeUpload(self): + def beforeUpload(self) -> None: for k in list(self.tags.keys()): self.tags[k] = 0 self.save() diff --git a/anki/template/__init__.py b/anki/template/__init__.py index 291b3e0e2..afb92d515 100644 --- a/anki/template/__init__.py +++ b/anki/template/__init__.py @@ -1,8 +1,9 @@ from .template import Template from . import furigana; furigana.install() from . import hint; hint.install() +from typing import Any -def render(template, context=None, **kwargs): +def render(template, context=None, **kwargs) -> Any: context = context and context.copy() or {} context.update(kwargs) return Template(template, context).render() diff --git a/anki/template/furigana.py b/anki/template/furigana.py index 929b0817d..bef895c1d 100644 --- a/anki/template/furigana.py +++ b/anki/template/furigana.py @@ -5,11 +5,12 @@ import re from anki.hooks import addHook +from typing import Any, Callable r = r' ?([^ >]+?)\[(.+?)\]' ruby = r'\1\2' -def noSound(repl): +def noSound(repl) -> Callable[[Any], Any]: def func(match): if match.group(2).startswith("sound:"): # return without modification @@ -18,19 +19,19 @@ def noSound(repl): return re.sub(r, repl, match.group(0)) return func -def _munge(s): +def _munge(s) -> Any: return s.replace(" ", " ") -def kanji(txt, *args): +def kanji(txt, *args) -> str: return re.sub(r, noSound(r'\1'), _munge(txt)) -def kana(txt, *args): +def kana(txt, *args) -> str: return re.sub(r, noSound(r'\2'), _munge(txt)) -def furigana(txt, *args): +def furigana(txt, *args) -> str: return re.sub(r, noSound(ruby), _munge(txt)) -def install(): +def install() -> None: addHook('fmod_kanji', kanji) addHook('fmod_kana', kana) addHook('fmod_furigana', furigana) diff --git a/anki/template/hint.py b/anki/template/hint.py index ad4e9e6b4..d2bfaf65a 100644 --- a/anki/template/hint.py +++ b/anki/template/hint.py @@ -5,7 +5,7 @@ from anki.hooks import addHook from anki.lang import _ -def hint(txt, extra, context, tag, fullname): +def hint(txt, extra, context, tag, fullname) -> str: if not txt.strip(): return "" # random id @@ -16,5 +16,5 @@ onclick="this.style.display='none';document.getElementById('%s').style.display=' %s """ % (domid, _("Show %s") % tag, domid, txt) -def install(): +def install() -> None: addHook('fmod_hint', hint) diff --git a/anki/template/template.py b/anki/template/template.py index ad53d47e9..2d6881ad6 100644 --- a/anki/template/template.py +++ b/anki/template/template.py @@ -1,11 +1,12 @@ import re from anki.utils import stripHTML, stripHTMLMedia from anki.hooks import runFilter +from typing import Any, Callable, NoReturn, Optional clozeReg = r"(?si)\{\{(c)%s::(.*?)(::(.*?))?\}\}" modifiers = {} -def modifier(symbol): +def modifier(symbol) -> Callable[[Any], Any]: """Decorator for associating a function with a Mustache tag modifier. @modifier('P') @@ -20,7 +21,7 @@ def modifier(symbol): return set_modifier -def get_or_attr(obj, name, default=None): +def get_or_attr(obj, name, default=None) -> Any: try: return obj[name] except KeyError: @@ -45,12 +46,12 @@ class Template: # Closing tag delimiter ctag = '}}' - def __init__(self, template, context=None): + def __init__(self, template, context=None) -> None: self.template = template self.context = context or {} self.compile_regexps() - def render(self, template=None, context=None, encoding=None): + def render(self, template=None, context=None, encoding=None) -> str: """Turns a Mustache template into something wonderful.""" template = template or self.template context = context or self.context @@ -61,7 +62,7 @@ class Template: result = result.encode(encoding) return result - def compile_regexps(self): + def compile_regexps(self) -> None: """Compiles our section and tag regular expressions.""" tags = { 'otag': re.escape(self.otag), 'ctag': re.escape(self.ctag) } @@ -71,7 +72,7 @@ class Template: tag = r"%(otag)s(#|=|&|!|>|\{)?(.+?)\1?%(ctag)s+" self.tag_re = re.compile(tag % tags) - def render_sections(self, template, context): + def render_sections(self, template, context) -> NoReturn: """Expands sections.""" while 1: match = self.section_re.search(template) @@ -104,7 +105,7 @@ class Template: return template - def render_tags(self, template, context): + def render_tags(self, template, context) -> str: """Renders all the tags in a template for a context.""" repCount = 0 while 1: @@ -130,16 +131,16 @@ class Template: # {{{ functions just like {{ in anki @modifier('{') - def render_tag(self, tag_name, context): + def render_tag(self, tag_name, context) -> Any: return self.render_unescaped(tag_name, context) @modifier('!') - def render_comment(self, tag_name=None, context=None): + def render_comment(self, tag_name=None, context=None) -> str: """Rendering a comment always returns nothing.""" return '' @modifier(None) - def render_unescaped(self, tag_name=None, context=None): + def render_unescaped(self, tag_name=None, context=None) -> Any: """Render a tag without escaping it.""" txt = get_or_attr(context, tag_name) if txt is not None: @@ -192,7 +193,7 @@ class Template: return '{unknown field %s}' % tag_name return txt - def clozeText(self, txt, ord, type): + def clozeText(self, txt, ord, type) -> str: reg = clozeReg if not re.search(reg%ord, txt): return "" @@ -215,7 +216,7 @@ class Template: return re.sub(reg%r"\d+", "\\2", txt) # look for clozes wrapped in mathjax, and change {{cx to {{Cx - def _removeFormattingFromMathjax(self, txt, ord): + def _removeFormattingFromMathjax(self, txt, ord) -> str: opening = ["\\(", "\\["] closing = ["\\)", "\\]"] # flags in middle of expression deprecated @@ -237,7 +238,7 @@ class Template: return txt @modifier('=') - def render_delimiter(self, tag_name=None, context=None): + def render_delimiter(self, tag_name=None, context=None) -> Optional[str]: """Changes the Mustache delimiter.""" try: self.otag, self.ctag = tag_name.split(' ') diff --git a/anki/template/view.py b/anki/template/view.py index 99110e071..29c8e0906 100644 --- a/anki/template/view.py +++ b/anki/template/view.py @@ -1,6 +1,7 @@ from .template import Template import os.path import re +from typing import Any class View: # Path where this view's template(s) live @@ -24,7 +25,7 @@ class View: # do any decoding of the template. template_encoding = None - def __init__(self, template=None, context=None, **kwargs): + def __init__(self, template=None, context=None, **kwargs) -> None: self.template = template self.context = context or {} @@ -36,7 +37,7 @@ class View: if kwargs: self.context.update(kwargs) - def inherit_settings(self, view): + def inherit_settings(self, view) -> None: """Given another View, copies its settings.""" if view.template_path: self.template_path = view.template_path @@ -44,7 +45,7 @@ class View: if view.template_name: self.template_name = view.template_name - def load_template(self): + def load_template(self) -> Any: if self.template: return self.template @@ -65,7 +66,7 @@ class View: raise IOError('"%s" not found in "%s"' % (name, ':'.join(self.template_path),)) - def _load_template(self): + def _load_template(self) -> str: f = open(self.template_file, 'r') try: template = f.read() @@ -75,7 +76,7 @@ class View: f.close() return template - def get_template_name(self, name=None): + def get_template_name(self, name=None) -> Any: """TemplatePartial => template_partial Takes a string but defaults to using the current class' name or the `template_name` attribute @@ -91,16 +92,16 @@ class View: return re.sub('[A-Z]', repl, name)[1:] - def __contains__(self, needle): + def __contains__(self, needle) -> bool: return needle in self.context or hasattr(self, needle) - def __getitem__(self, attr): + def __getitem__(self, attr) -> Any: val = self.get(attr, None) if not val: raise KeyError("No such key.") return val - def get(self, attr, default): + def get(self, attr, default) -> Any: attr = self.context.get(attr, getattr(self, attr, default)) if hasattr(attr, '__call__'): @@ -108,9 +109,9 @@ class View: else: return attr - def render(self, encoding=None): + def render(self, encoding=None) -> str: template = self.load_template() return Template(template, self).render(encoding=encoding) - def __str__(self): + def __str__(self) -> str: return self.render() diff --git a/anki/utils.py b/anki/utils.py index 10f091dbd..cec62d6e4 100644 --- a/anki/utils.py +++ b/anki/utils.py @@ -22,11 +22,14 @@ from anki.lang import _, ngettext # some add-ons expect json to be in the utils module import json # pylint: disable=unused-import +from typing import Any, Optional, Tuple + +_tmpdir: Optional[str] # Time handling ############################################################################## -def intTime(scale=1): +def intTime(scale=1) -> int: "The time in integer seconds. Pass scale=1000 to get milliseconds." return int(time.time()*scale) @@ -48,7 +51,7 @@ inTimeTable = { "seconds": lambda n: ngettext("in %s second", "in %s seconds", n), } -def shortTimeFmt(type): +def shortTimeFmt(type) -> Any: return { #T: year is an abbreviation for year. %s is a number of years "years": _("%sy"), @@ -64,7 +67,7 @@ def shortTimeFmt(type): "seconds": _("%ss"), }[type] -def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99): +def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99) -> str: "Return a string representing a time span (eg '2 days')." (type, point) = optimalPeriod(time, point, unit) time = convertSecondsTo(time, type) @@ -80,7 +83,7 @@ def fmtTimeSpan(time, pad=0, point=0, short=False, inTime=False, unit=99): timestr = "%%%(a)d.%(b)df" % {'a': pad, 'b': point} return locale.format_string(fmt % timestr, time) -def optimalPeriod(time, point, unit): +def optimalPeriod(time, point, unit) -> Tuple[str, Any]: if abs(time) < 60 or unit < 1: type = "seconds" point -= 1 @@ -98,7 +101,7 @@ def optimalPeriod(time, point, unit): point += 1 return (type, max(point, 0)) -def convertSecondsTo(seconds, type): +def convertSecondsTo(seconds, type) -> Any: if type == "seconds": return seconds elif type == "minutes": @@ -113,7 +116,7 @@ def convertSecondsTo(seconds, type): return seconds / 31536000 assert False -def _pluralCount(time, point): +def _pluralCount(time, point) -> int: if point: return 2 return math.floor(time) @@ -121,12 +124,12 @@ def _pluralCount(time, point): # Locale ############################################################################## -def fmtPercentage(float_value, point=1): +def fmtPercentage(float_value, point=1) -> str: "Return float with percentage sign" fmt = '%' + "0.%(b)df" % {'b': point} return locale.format_string(fmt, float_value) + "%" -def fmtFloat(float_value, point=1): +def fmtFloat(float_value, point=1) -> str: "Return a string with decimal separator according to current locale" fmt = '%' + "0.%(b)df" % {'b': point} return locale.format_string(fmt, float_value) @@ -140,7 +143,7 @@ reTag = re.compile("(?s)<.*?>") reEnts = re.compile(r"&#?\w+;") reMedia = re.compile("(?i)]+src=[\"']?([^\"'>]+)[\"']?[^>]*>") -def stripHTML(s): +def stripHTML(s) -> str: s = reComment.sub("", s) s = reStyle.sub("", s) s = reScript.sub("", s) @@ -148,12 +151,12 @@ def stripHTML(s): s = entsToTxt(s) return s -def stripHTMLMedia(s): +def stripHTMLMedia(s) -> Any: "Strip HTML but keep media filenames" s = reMedia.sub(" \\1 ", s) return stripHTML(s) -def minimizeHTML(s): +def minimizeHTML(s) -> str: "Correct Qt's verbose bold/underline/etc." s = re.sub('(.*?)', '\\1', s) @@ -163,7 +166,7 @@ def minimizeHTML(s): '\\1', s) return s -def htmlToTextLine(s): +def htmlToTextLine(s) -> Any: s = s.replace("
", " ") s = s.replace("
", " ") s = s.replace("
", " ") @@ -174,7 +177,7 @@ def htmlToTextLine(s): s = s.strip() return s -def entsToTxt(html): +def entsToTxt(html) -> str: # entitydefs defines nbsp as \xa0 instead of a standard space, so we # replace it first html = html.replace(" ", " ") @@ -198,7 +201,7 @@ def entsToTxt(html): return text # leave as is return reEnts.sub(fixup, html) -def bodyClass(col, card): +def bodyClass(col, card) -> str: bodyclass = "card card%d" % (card.ord+1) if col.conf.get("nightMode"): bodyclass += " nightMode" @@ -207,17 +210,17 @@ def bodyClass(col, card): # IDs ############################################################################## -def hexifyID(id): +def hexifyID(id) -> str: return "%x" % int(id) -def dehexifyID(id): +def dehexifyID(id) -> int: return int(id, 16) -def ids2str(ids): +def ids2str(ids) -> str: """Given a list of integers, return a string '(int1,int2,...)'.""" return "(%s)" % ",".join(str(i) for i in ids) -def timestampID(db, table): +def timestampID(db, table) -> int: "Return a non-conflicting timestamp for table." # be careful not to create multiple objects without flushing them, or they # may share an ID. @@ -226,7 +229,7 @@ def timestampID(db, table): t += 1 return t -def maxID(db): +def maxID(db) -> Any: "Return the first safe ID to use." now = intTime(1000) for tbl in "cards", "notes": @@ -234,7 +237,7 @@ def maxID(db): return now + 1 # used in ankiweb -def base62(num, extra=""): +def base62(num, extra="") -> str: s = string; table = s.ascii_letters + s.digits + extra buf = "" while num: @@ -243,19 +246,19 @@ def base62(num, extra=""): return buf _base91_extra_chars = "!#$%&()*+,-./:;<=>?@[]^_`{|}~" -def base91(num): +def base91(num) -> str: # all printable characters minus quotes, backslash and separators return base62(num, _base91_extra_chars) -def guid64(): +def guid64() -> Any: "Return a base91-encoded 64bit random number." return base91(random.randint(0, 2**64-1)) # increment a guid by one, for note type conflicts -def incGuid(guid): +def incGuid(guid) -> str: return _incGuid(guid[::-1])[::-1] -def _incGuid(guid): +def _incGuid(guid) -> str: s = string; table = s.ascii_letters + s.digits + _base91_extra_chars idx = table.index(guid[0]) if idx + 1 == len(table): @@ -268,21 +271,21 @@ def _incGuid(guid): # Fields ############################################################################## -def joinFields(list): +def joinFields(list) -> str: return "\x1f".join(list) -def splitFields(string): +def splitFields(string) -> Any: return string.split("\x1f") # Checksums ############################################################################## -def checksum(data): +def checksum(data) -> str: if isinstance(data, str): data = data.encode("utf-8") return sha1(data).hexdigest() -def fieldChecksum(data): +def fieldChecksum(data) -> int: # 32 bit unsigned number from first 8 digits of sha1 hash return int(checksum(stripHTMLMedia(data).encode("utf-8"))[:8], 16) @@ -291,7 +294,7 @@ def fieldChecksum(data): _tmpdir = None -def tmpdir(): +def tmpdir() -> Any: "A reusable temp folder which we clean out on each program invocation." global _tmpdir if not _tmpdir: @@ -305,12 +308,12 @@ def tmpdir(): os.mkdir(_tmpdir) return _tmpdir -def tmpfile(prefix="", suffix=""): +def tmpfile(prefix="", suffix="") -> Any: (fd, name) = tempfile.mkstemp(dir=tmpdir(), prefix=prefix, suffix=suffix) os.close(fd) return name -def namedtmp(name, rm=True): +def namedtmp(name, rm=True) -> Any: "Return tmpdir+name. Deletes any existing file." path = os.path.join(tmpdir(), name) if rm: @@ -330,7 +333,7 @@ def noBundledLibs(): if oldlpath is not None: os.environ["LD_LIBRARY_PATH"] = oldlpath -def call(argv, wait=True, **kwargs): +def call(argv, wait=True, **kwargs) -> int: "Execute a command. If WAIT, return exit code." # ensure we don't open a separate window for forking process on windows if isWin: @@ -372,7 +375,7 @@ devMode = os.getenv("ANKIDEV", "") invalidFilenameChars = ":*?\"<>|" -def invalidFilename(str, dirsep=True): +def invalidFilename(str, dirsep=True) -> Optional[str]: for c in invalidFilenameChars: if c in str: return c @@ -383,7 +386,7 @@ def invalidFilename(str, dirsep=True): elif str.strip().startswith("."): return "." -def platDesc(): +def platDesc() -> str: # we may get an interrupted system call, so try this in a loop n = 0 theos = "unknown" @@ -410,9 +413,9 @@ def platDesc(): ############################################################################## class TimedLog: - def __init__(self): + def __init__(self) -> None: self._last = time.time() - def log(self, s): + def log(self, s) -> None: path, num, fn, y = traceback.extract_stack(limit=2)[0] sys.stderr.write("%5dms: %s(): %s\n" % ((time.time() - self._last)*1000, fn, s)) self._last = time.time() @@ -420,7 +423,7 @@ class TimedLog: # Version ############################################################################## -def versionWithBuild(): +def versionWithBuild() -> str: from anki import version try: from anki.buildhash import build # type: ignore diff --git a/aqt/main.py b/aqt/main.py index f098adc6a..32c294bf8 100644 --- a/aqt/main.py +++ b/aqt/main.py @@ -13,6 +13,7 @@ from threading import Thread from typing import Optional from send2trash import send2trash from anki.collection import _Collection +from aqt.profiles import ProfileManager as ProfileManagerType from aqt.qt import * from anki.storage import Collection from anki.utils import isWin, isMac, intTime, splitFields, ids2str, \ @@ -33,7 +34,7 @@ from aqt.qt import sip from anki.lang import _, ngettext class AnkiQt(QMainWindow): - def __init__(self, app: QApplication, profileManager, opts, args): + def __init__(self, app: QApplication, profileManager: ProfileManagerType, opts, args): QMainWindow.__init__(self) self.state = "startup" self.opts = opts