tweak db type hints

This commit is contained in:
Damien Elmes 2020-03-03 12:05:33 +10:00
parent b5c6134d80
commit 77cf7dd4b7
3 changed files with 23 additions and 13 deletions

View File

@ -5,7 +5,15 @@
# fixme: progress
from sqlite3 import dbapi2 as sqlite
from typing import Any, Iterable, List, Optional
from typing import Any, Iterable, List, Optional, Sequence, Union
# DBValue is actually Union[str, int, float, None], but if defined
# that way, every call site needs to do a type check prior to using
# the return values.
ValueFromDB = Any
Row = Sequence[ValueFromDB]
ValueForDB = Union[str, int, float, None]
class DBProxy:
@ -38,7 +46,9 @@ class DBProxy:
# Querying
################
def _query(self, sql: str, *args, first_row_only: bool = False) -> List[List]:
def _query(
self, sql: str, *args: ValueForDB, first_row_only: bool = False
) -> List[Row]:
# mark modified?
s = sql.strip().lower()
for stmt in "insert", "update", "delete":
@ -59,20 +69,20 @@ class DBProxy:
# Query shortcuts
###################
def all(self, sql: str, *args) -> List:
def all(self, sql: str, *args: ValueForDB) -> List[Row]:
return self._query(sql, *args)
def list(self, sql: str, *args) -> List:
def list(self, sql: str, *args: ValueForDB) -> List[ValueFromDB]:
return [x[0] for x in self._query(sql, *args)]
def first(self, sql: str, *args) -> Optional[List]:
def first(self, sql: str, *args: ValueForDB) -> Optional[Row]:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0]
else:
return None
def scalar(self, sql: str, *args) -> Optional[Any]:
def scalar(self, sql: str, *args: ValueForDB) -> ValueFromDB:
rows = self._query(sql, *args, first_row_only=True)
if rows:
return rows[0][0]
@ -86,7 +96,7 @@ class DBProxy:
# Updates
################
def executemany(self, sql: str, args: Iterable) -> None:
def executemany(self, sql: str, args: Iterable[Iterable[ValueForDB]]) -> None:
self.mod = True
self._db.executemany(sql, args)

View File

@ -138,8 +138,8 @@ class Scheduler:
def dueForecast(self, days: int = 7) -> List[Any]:
"Return counts over next DAYS. Includes today."
daysd = dict(
self.col.db.all(
daysd: Dict[int, int] = dict(
self.col.db.all( # type: ignore
f"""
select due, count() from cards
where did in %s and queue = {QUEUE_TYPE_REV}
@ -542,7 +542,7 @@ select count() from cards where did in %s and queue = {QUEUE_TYPE_PREVIEW}
if self._lrnQueue:
return True
cutoff = intTime() + self.col.conf["collapseTime"]
self._lrnQueue = self.col.db.all(
self._lrnQueue = self.col.db.all( # type: ignore
f"""
select due, id from cards where
did in %s and queue in ({QUEUE_TYPE_LRN},{QUEUE_TYPE_PREVIEW}) and due < ?

View File

@ -8,7 +8,7 @@ import io
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import anki
from anki.consts import *
@ -31,7 +31,7 @@ class UnexpectedSchemaChange(Exception):
class Syncer:
chunkRows: Optional[List[List]]
chunkRows: Optional[List[Sequence]]
def __init__(self, col: anki.storage._Collection, server=None) -> None:
self.col = col.weakref()
@ -248,7 +248,7 @@ class Syncer:
self.tablesLeft = ["revlog", "cards", "notes"]
self.chunkRows = None
def getChunkRows(self, table) -> List[List]:
def getChunkRows(self, table) -> List[Sequence]:
lim = self.usnLim()
x = self.col.db.all
d = (self.maxUsn, lim)