clone db.py into dbproxy.py

This commit is contained in:
Damien Elmes 2020-03-02 14:13:57 +10:00
parent 0d43e9dca3
commit c1252d68f0
4 changed files with 127 additions and 15 deletions

View File

@ -23,7 +23,7 @@ import anki.template
from anki import hooks
from anki.cards import Card
from anki.consts import *
from anki.db import DB
from anki.dbproxy import DBProxy
from anki.decks import DeckManager
from anki.errors import AnkiError
from anki.lang import _, ngettext
@ -67,7 +67,7 @@ defaultConf = {
# this is initialized by storage.Collection
class _Collection:
db: Optional[DB]
db: Optional[DBProxy]
sched: Union[V1Scheduler, V2Scheduler]
crt: int
mod: int
@ -80,7 +80,7 @@ class _Collection:
def __init__(
self,
db: DB,
db: DBProxy,
backend: RustBackend,
server: Optional["anki.storage.ServerData"] = None,
log: bool = False,
@ -267,7 +267,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
def reopen(self) -> None:
"Reconnect to DB (after changing threads, etc)."
if not self.db:
self.db = DB(self.path)
self.db = DBProxy(self.path)
self.media.connect()
self._openLog()

112
pylib/anki/dbproxy.py Normal file
View File

@ -0,0 +1,112 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
import os
import time
from sqlite3 import Cursor
from sqlite3 import dbapi2 as sqlite
from typing import Any, List, Type
class DBProxy:
def __init__(self, path: str, timeout: int = 0) -> None:
self._db = sqlite.connect(path, timeout=timeout)
self._db.text_factory = self._textFactory
self._path = path
self.echo = os.environ.get("DBECHO")
self.mod = False
def execute(self, sql: str, *a, **ka) -> Cursor:
s = sql.strip().lower()
# mark modified?
for stmt in "insert", "update", "delete":
if s.startswith(stmt):
self.mod = True
t = time.time()
if ka:
# execute("...where id = :id", id=5)
res = self._db.execute(sql, ka)
else:
# execute("...where id = ?", 5)
res = self._db.execute(sql, a)
if self.echo:
# print a, ka
print(sql, "%0.3fms" % ((time.time() - t) * 1000))
if self.echo == "2":
print(a, ka)
return res
def executemany(self, sql: str, l: Any) -> None:
self.mod = True
t = time.time()
self._db.executemany(sql, l)
if self.echo:
print(sql, "%0.3fms" % ((time.time() - t) * 1000))
if self.echo == "2":
print(l)
def commit(self) -> None:
t = time.time()
self._db.commit()
if self.echo:
print("commit %0.3fms" % ((time.time() - t) * 1000))
def executescript(self, sql: str) -> None:
self.mod = True
if self.echo:
print(sql)
self._db.executescript(sql)
def rollback(self) -> None:
self._db.rollback()
def scalar(self, *a, **kw) -> Any:
res = self.execute(*a, **kw).fetchone()
if res:
return res[0]
return None
def all(self, *a, **kw) -> List:
return self.execute(*a, **kw).fetchall()
def first(self, *a, **kw) -> Any:
c = self.execute(*a, **kw)
res = c.fetchone()
c.close()
return res
def list(self, *a, **kw) -> List:
return [x[0] for x in self.execute(*a, **kw)]
def close(self) -> None:
self._db.text_factory = None
self._db.close()
def set_progress_handler(self, *args) -> None:
self._db.set_progress_handler(*args)
def __enter__(self) -> "DBProxy":
self._db.execute("begin")
return self
def __exit__(self, exc_type, *args) -> None:
self._db.close()
def totalChanges(self) -> Any:
return self._db.total_changes
def interrupt(self) -> None:
self._db.interrupt()
def setAutocommit(self, autocommit: bool) -> None:
if autocommit:
self._db.isolation_level = None
else:
self._db.isolation_level = ""
# strip out invalid utf-8 when reading from db
def _textFactory(self, data: bytes) -> str:
return str(data, errors="ignore")
def cursor(self, factory: Type[Cursor] = Cursor) -> Cursor:
return self._db.cursor(factory)

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, Optional, Tuple
from anki.collection import _Collection
from anki.consts import *
from anki.db import DB
from anki.dbproxy import DBProxy
from anki.lang import _
from anki.media import media_paths_from_col_path
from anki.rsbackend import RustBackend
@ -44,7 +44,7 @@ def Collection(
for c in ("/", ":", "\\"):
assert c not in base
# connect
db = DB(path)
db = DBProxy(path)
db.setAutocommit(True)
if create:
ver = _createDB(db)
@ -78,7 +78,7 @@ def Collection(
return col
def _upgradeSchema(db: DB) -> Any:
def _upgradeSchema(db: DBProxy) -> Any:
ver = db.scalar("select ver from col")
if ver == SCHEMA_VERSION:
return ver
@ -238,7 +238,7 @@ def _upgradeClozeModel(col, m) -> None:
######################################################################
def _createDB(db: DB) -> int:
def _createDB(db: DBProxy) -> int:
db.execute("pragma page_size = 4096")
db.execute("pragma legacy_file_format = 0")
db.execute("vacuum")
@ -248,7 +248,7 @@ def _createDB(db: DB) -> int:
return SCHEMA_VERSION
def _addSchema(db: DB, setColConf: bool = True) -> None:
def _addSchema(db: DBProxy, setColConf: bool = True) -> None:
db.executescript(
"""
create table if not exists col (
@ -329,7 +329,7 @@ values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
_addColVars(db, *_getColVars(db))
def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]:
import anki.collection
import anki.decks
@ -344,7 +344,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
def _addColVars(
db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
) -> None:
db.execute(
"""
@ -355,7 +355,7 @@ update col set conf = ?, decks = ?, dconf = ?""",
)
def _updateIndices(db: DB) -> None:
def _updateIndices(db: DBProxy) -> None:
"Add indices to the DB."
db.executescript(
"""

View File

@ -22,7 +22,7 @@ from hashlib import sha1
from html.entities import name2codepoint
from typing import Iterable, Iterator, List, Optional, Union
from anki.db import DB
from anki.dbproxy import DBProxy
_tmpdir: Optional[str]
@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str:
return "(%s)" % ",".join(str(i) for i in ids)
def timestampID(db: DB, table: str) -> int:
def timestampID(db: DBProxy, table: str) -> int:
"Return a non-conflicting timestamp for table."
# be careful not to create multiple objects without flushing them, or they
# may share an ID.
@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int:
return t
def maxID(db: DB) -> int:
def maxID(db: DBProxy) -> int:
"Return the first safe ID to use."
now = intTime(1000)
for tbl in "cards", "notes":