clone db.py into dbproxy.py
This commit is contained in:
parent
0d43e9dca3
commit
c1252d68f0
@ -23,7 +23,7 @@ import anki.template
|
||||
from anki import hooks
|
||||
from anki.cards import Card
|
||||
from anki.consts import *
|
||||
from anki.db import DB
|
||||
from anki.dbproxy import DBProxy
|
||||
from anki.decks import DeckManager
|
||||
from anki.errors import AnkiError
|
||||
from anki.lang import _, ngettext
|
||||
@ -67,7 +67,7 @@ defaultConf = {
|
||||
|
||||
# this is initialized by storage.Collection
|
||||
class _Collection:
|
||||
db: Optional[DB]
|
||||
db: Optional[DBProxy]
|
||||
sched: Union[V1Scheduler, V2Scheduler]
|
||||
crt: int
|
||||
mod: int
|
||||
@ -80,7 +80,7 @@ class _Collection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: DB,
|
||||
db: DBProxy,
|
||||
backend: RustBackend,
|
||||
server: Optional["anki.storage.ServerData"] = None,
|
||||
log: bool = False,
|
||||
@ -267,7 +267,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
|
||||
def reopen(self) -> None:
|
||||
"Reconnect to DB (after changing threads, etc)."
|
||||
if not self.db:
|
||||
self.db = DB(self.path)
|
||||
self.db = DBProxy(self.path)
|
||||
self.media.connect()
|
||||
self._openLog()
|
||||
|
||||
|
112
pylib/anki/dbproxy.py
Normal file
112
pylib/anki/dbproxy.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Copyright: Ankitects Pty Ltd and contributors
|
||||
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||
|
||||
import os
|
||||
import time
|
||||
from sqlite3 import Cursor
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
from typing import Any, List, Type
|
||||
|
||||
|
||||
class DBProxy:
|
||||
def __init__(self, path: str, timeout: int = 0) -> None:
|
||||
self._db = sqlite.connect(path, timeout=timeout)
|
||||
self._db.text_factory = self._textFactory
|
||||
self._path = path
|
||||
self.echo = os.environ.get("DBECHO")
|
||||
self.mod = False
|
||||
|
||||
def execute(self, sql: str, *a, **ka) -> Cursor:
|
||||
s = sql.strip().lower()
|
||||
# mark modified?
|
||||
for stmt in "insert", "update", "delete":
|
||||
if s.startswith(stmt):
|
||||
self.mod = True
|
||||
t = time.time()
|
||||
if ka:
|
||||
# execute("...where id = :id", id=5)
|
||||
res = self._db.execute(sql, ka)
|
||||
else:
|
||||
# execute("...where id = ?", 5)
|
||||
res = self._db.execute(sql, a)
|
||||
if self.echo:
|
||||
# print a, ka
|
||||
print(sql, "%0.3fms" % ((time.time() - t) * 1000))
|
||||
if self.echo == "2":
|
||||
print(a, ka)
|
||||
return res
|
||||
|
||||
def executemany(self, sql: str, l: Any) -> None:
|
||||
self.mod = True
|
||||
t = time.time()
|
||||
self._db.executemany(sql, l)
|
||||
if self.echo:
|
||||
print(sql, "%0.3fms" % ((time.time() - t) * 1000))
|
||||
if self.echo == "2":
|
||||
print(l)
|
||||
|
||||
def commit(self) -> None:
|
||||
t = time.time()
|
||||
self._db.commit()
|
||||
if self.echo:
|
||||
print("commit %0.3fms" % ((time.time() - t) * 1000))
|
||||
|
||||
def executescript(self, sql: str) -> None:
|
||||
self.mod = True
|
||||
if self.echo:
|
||||
print(sql)
|
||||
self._db.executescript(sql)
|
||||
|
||||
def rollback(self) -> None:
|
||||
self._db.rollback()
|
||||
|
||||
def scalar(self, *a, **kw) -> Any:
|
||||
res = self.execute(*a, **kw).fetchone()
|
||||
if res:
|
||||
return res[0]
|
||||
return None
|
||||
|
||||
def all(self, *a, **kw) -> List:
|
||||
return self.execute(*a, **kw).fetchall()
|
||||
|
||||
def first(self, *a, **kw) -> Any:
|
||||
c = self.execute(*a, **kw)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res
|
||||
|
||||
def list(self, *a, **kw) -> List:
|
||||
return [x[0] for x in self.execute(*a, **kw)]
|
||||
|
||||
def close(self) -> None:
|
||||
self._db.text_factory = None
|
||||
self._db.close()
|
||||
|
||||
def set_progress_handler(self, *args) -> None:
|
||||
self._db.set_progress_handler(*args)
|
||||
|
||||
def __enter__(self) -> "DBProxy":
|
||||
self._db.execute("begin")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, *args) -> None:
|
||||
self._db.close()
|
||||
|
||||
def totalChanges(self) -> Any:
|
||||
return self._db.total_changes
|
||||
|
||||
def interrupt(self) -> None:
|
||||
self._db.interrupt()
|
||||
|
||||
def setAutocommit(self, autocommit: bool) -> None:
|
||||
if autocommit:
|
||||
self._db.isolation_level = None
|
||||
else:
|
||||
self._db.isolation_level = ""
|
||||
|
||||
# strip out invalid utf-8 when reading from db
|
||||
def _textFactory(self, data: bytes) -> str:
|
||||
return str(data, errors="ignore")
|
||||
|
||||
def cursor(self, factory: Type[Cursor] = Cursor) -> Cursor:
|
||||
return self._db.cursor(factory)
|
@ -9,7 +9,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from anki.collection import _Collection
|
||||
from anki.consts import *
|
||||
from anki.db import DB
|
||||
from anki.dbproxy import DBProxy
|
||||
from anki.lang import _
|
||||
from anki.media import media_paths_from_col_path
|
||||
from anki.rsbackend import RustBackend
|
||||
@ -44,7 +44,7 @@ def Collection(
|
||||
for c in ("/", ":", "\\"):
|
||||
assert c not in base
|
||||
# connect
|
||||
db = DB(path)
|
||||
db = DBProxy(path)
|
||||
db.setAutocommit(True)
|
||||
if create:
|
||||
ver = _createDB(db)
|
||||
@ -78,7 +78,7 @@ def Collection(
|
||||
return col
|
||||
|
||||
|
||||
def _upgradeSchema(db: DB) -> Any:
|
||||
def _upgradeSchema(db: DBProxy) -> Any:
|
||||
ver = db.scalar("select ver from col")
|
||||
if ver == SCHEMA_VERSION:
|
||||
return ver
|
||||
@ -238,7 +238,7 @@ def _upgradeClozeModel(col, m) -> None:
|
||||
######################################################################
|
||||
|
||||
|
||||
def _createDB(db: DB) -> int:
|
||||
def _createDB(db: DBProxy) -> int:
|
||||
db.execute("pragma page_size = 4096")
|
||||
db.execute("pragma legacy_file_format = 0")
|
||||
db.execute("vacuum")
|
||||
@ -248,7 +248,7 @@ def _createDB(db: DB) -> int:
|
||||
return SCHEMA_VERSION
|
||||
|
||||
|
||||
def _addSchema(db: DB, setColConf: bool = True) -> None:
|
||||
def _addSchema(db: DBProxy, setColConf: bool = True) -> None:
|
||||
db.executescript(
|
||||
"""
|
||||
create table if not exists col (
|
||||
@ -329,7 +329,7 @@ values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
|
||||
_addColVars(db, *_getColVars(db))
|
||||
|
||||
|
||||
def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
|
||||
def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]:
|
||||
import anki.collection
|
||||
import anki.decks
|
||||
|
||||
@ -344,7 +344,7 @@ def _getColVars(db: DB) -> Tuple[Any, Any, Dict[str, Any]]:
|
||||
|
||||
|
||||
def _addColVars(
|
||||
db: DB, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
|
||||
db: DBProxy, g: Dict[str, Any], gc: Dict[str, Any], c: Dict[str, Any]
|
||||
) -> None:
|
||||
db.execute(
|
||||
"""
|
||||
@ -355,7 +355,7 @@ update col set conf = ?, decks = ?, dconf = ?""",
|
||||
)
|
||||
|
||||
|
||||
def _updateIndices(db: DB) -> None:
|
||||
def _updateIndices(db: DBProxy) -> None:
|
||||
"Add indices to the DB."
|
||||
db.executescript(
|
||||
"""
|
||||
|
@ -22,7 +22,7 @@ from hashlib import sha1
|
||||
from html.entities import name2codepoint
|
||||
from typing import Iterable, Iterator, List, Optional, Union
|
||||
|
||||
from anki.db import DB
|
||||
from anki.dbproxy import DBProxy
|
||||
|
||||
_tmpdir: Optional[str]
|
||||
|
||||
@ -142,7 +142,7 @@ def ids2str(ids: Iterable[Union[int, str]]) -> str:
|
||||
return "(%s)" % ",".join(str(i) for i in ids)
|
||||
|
||||
|
||||
def timestampID(db: DB, table: str) -> int:
|
||||
def timestampID(db: DBProxy, table: str) -> int:
|
||||
"Return a non-conflicting timestamp for table."
|
||||
# be careful not to create multiple objects without flushing them, or they
|
||||
# may share an ID.
|
||||
@ -152,7 +152,7 @@ def timestampID(db: DB, table: str) -> int:
|
||||
return t
|
||||
|
||||
|
||||
def maxID(db: DB) -> int:
|
||||
def maxID(db: DBProxy) -> int:
|
||||
"Return the first safe ID to use."
|
||||
now = intTime(1000)
|
||||
for tbl in "cards", "notes":
|
||||
|
Loading…
Reference in New Issue
Block a user