diff --git a/pylib/anki/collection.py b/pylib/anki/collection.py index 615c86df9..989d18925 100644 --- a/pylib/anki/collection.py +++ b/pylib/anki/collection.py @@ -242,10 +242,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", return None def lock(self) -> None: - # make sure we don't accidentally bump mod time - mod = self.db.mod - self.db.execute("update col set mod=mod") - self.db.mod = mod + self.db.begin() def close(self, save: bool = True) -> None: "Disconnect from DB." @@ -260,6 +257,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""", self.db.setAutocommit(False) self.db.close() self.db = None + self.backend = None self.media.close() self._closeLog() diff --git a/pylib/anki/dbproxy.py b/pylib/anki/dbproxy.py index d18c0cce2..ae5a8f779 100644 --- a/pylib/anki/dbproxy.py +++ b/pylib/anki/dbproxy.py @@ -3,11 +3,14 @@ from __future__ import annotations +import weakref from typing import Any, Iterable, List, Optional, Sequence, Union import anki -# fixme: remember to null on close to avoid circular ref +# fixme: col.reopen() +# fixme: setAutocommit() +# fixme: transaction/lock handling # fixme: progress # DBValue is actually Union[str, int, float, None], but if defined @@ -24,7 +27,7 @@ class DBProxy: ############### def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None: - self._backend = backend + self._backend = weakref.proxy(backend) self._path = path self.mod = False @@ -35,17 +38,20 @@ class DBProxy: # Transactions ############### + def begin(self) -> None: + self._backend.db_begin() + def commit(self) -> None: - # fixme - pass + self._backend.db_commit() def rollback(self) -> None: - # fixme - pass + self._backend.db_rollback() def setAutocommit(self, autocommit: bool) -> None: - # fixme - pass + if autocommit: + self.commit() + else: + self.begin() # Querying ################ @@ -58,6 +64,7 @@ class DBProxy: for stmt in "insert", "update", "delete": if s.startswith(stmt): self.mod = True + assert ":" not in sql # fetch rows # fixme: first_row_only return self._backend.db_query(sql, args) diff --git a/pylib/anki/rsbackend.py b/pylib/anki/rsbackend.py index a44996ed9..ab2534a54 100644 --- a/pylib/anki/rsbackend.py +++ b/pylib/anki/rsbackend.py @@ -6,6 +6,7 @@ import enum import os from dataclasses import dataclass from typing import ( + Any, Callable, Dict, Iterable, @@ -15,7 +16,7 @@ from typing import ( Optional, Tuple, Union, - Any) +) import ankirspy # pytype: disable=import-error import orjson @@ -386,9 +387,20 @@ class RustBackend: self._run_command(pb.BackendInput(restore_trash=pb.Empty())) def db_query(self, sql: str, args: Iterable[ValueForDB]) -> List[DBRow]: - input = orjson.dumps(dict(sql=sql, args=args)) - output = self._backend.db_query(input) - return orjson.loads(output) + return self._db_command(dict(kind="query", sql=sql, args=args)) + + def db_begin(self) -> None: + return self._db_command(dict(kind="begin")) + + def db_commit(self) -> None: + return self._db_command(dict(kind="commit")) + + def db_rollback(self) -> None: + return self._db_command(dict(kind="rollback")) + + def _db_command(self, input: Dict[str, Any]) -> Any: + return orjson.loads(self._backend.db_command(orjson.dumps(input))) + def translate_string_in( key: TR, **kwargs: Union[str, int, float] diff --git a/pylib/anki/storage.py b/pylib/anki/storage.py index c81dac715..0b626d1bb 100644 --- a/pylib/anki/storage.py +++ b/pylib/anki/storage.py @@ -26,9 +26,7 @@ class ServerData: minutes_west: Optional[int] = None -def Collection( - path: str, lock: bool = True, server: Optional[ServerData] = None, log: bool = False -) -> _Collection: +def Collection(path: str, server: Optional[ServerData] = None) -> _Collection: "Open a new or existing collection. Path must be unicode." assert path.endswith(".anki2") (media_dir, media_db) = media_paths_from_col_path(path) @@ -36,33 +34,23 @@ def Collection( if not server: log_path = path.replace(".anki2", "2.log") path = os.path.abspath(path) - create = not os.path.exists(path) - if create: - base = os.path.basename(path) - for c in ("/", ":", "\\"): - assert c not in base # connect backend = RustBackend( path, media_dir, media_db, log_path, server=server is not None ) db = DBProxy(backend, path) + db.begin() db.setAutocommit(True) + + # initial setup required? + create = db.scalar("select models = '{}' from col") if create: - ver = _createDB(db) - else: - ver = _upgradeSchema(db) - db.execute("pragma temp_store = memory") - db.execute("pragma cache_size = 10000") - if not isWin: - db.execute("pragma journal_mode = wal") + initial_db_setup(db) + db.setAutocommit(False) # add db to col and do any remaining upgrades - col = _Collection(db, backend=backend, server=server, log=log) - if ver < SCHEMA_VERSION: - raise Exception("This file requires an older version of Anki.") - elif ver > SCHEMA_VERSION: - raise Exception("This file requires a newer version of Anki.") - elif create: + col = _Collection(db, backend=backend, server=server) + if create: # add in reverse order so basic is default addClozeModel(col) addBasicTypingModel(col) @@ -70,112 +58,15 @@ def Collection( addForwardReverse(col) addBasicModel(col) col.save() - if lock: - try: - col.lock() - except: - col.db.close() - raise return col -def _upgradeSchema(db: DBProxy) -> Any: - return db.scalar("select ver from col") - - # Creating a new collection ###################################################################### -def _createDB(db: DBProxy) -> int: - db.execute("pragma page_size = 4096") - db.execute("pragma legacy_file_format = 0") - db.execute("vacuum") - _addSchema(db) - _updateIndices(db) - db.execute("analyze") - return SCHEMA_VERSION - - -def _addSchema(db: DBProxy, setColConf: bool = True) -> None: - db.executescript( - """ -create table if not exists col ( - id integer primary key, - crt integer not null, - mod integer not null, - scm integer not null, - ver integer not null, - dty integer not null, - usn integer not null, - ls integer not null, - conf text not null, - models text not null, - decks text not null, - dconf text not null, - tags text not null -); - -create table if not exists notes ( - id integer primary key, /* 0 */ - guid text not null, /* 1 */ - mid integer not null, /* 2 */ - mod integer not null, /* 3 */ - usn integer not null, /* 4 */ - tags text not null, /* 5 */ - flds text not null, /* 6 */ - sfld integer not null, /* 7 */ - csum integer not null, /* 8 */ - flags integer not null, /* 9 */ - data text not null /* 10 */ -); - -create table if not exists cards ( - id integer primary key, /* 0 */ - nid integer not null, /* 1 */ - did integer not null, /* 2 */ - ord integer not null, /* 3 */ - mod integer not null, /* 4 */ - usn integer not null, /* 5 */ - type integer not null, /* 6 */ - queue integer not null, /* 7 */ - due integer not null, /* 8 */ - ivl integer not null, /* 9 */ - factor integer not null, /* 10 */ - reps integer not null, /* 11 */ - lapses integer not null, /* 12 */ - left integer not null, /* 13 */ - odue integer not null, /* 14 */ - odid integer not null, /* 15 */ - flags integer not null, /* 16 */ - data text not null /* 17 */ -); - -create table if not exists revlog ( - id integer primary key, - cid integer not null, - usn integer not null, - ease integer not null, - ivl integer not null, - lastIvl integer not null, - factor integer not null, - time integer not null, - type integer not null -); - -create table if not exists graves ( - usn integer not null, - oid integer not null, - type integer not null -); - -insert or ignore into col -values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}'); -""" - % ({"v": SCHEMA_VERSION, "s": intTime(1000)}) - ) - if setColConf: - _addColVars(db, *_getColVars(db)) +def initial_db_setup(db: DBProxy) -> None: + _addColVars(db, *_getColVars(db)) def _getColVars(db: DBProxy) -> Tuple[Any, Any, Dict[str, Any]]: @@ -202,23 +93,3 @@ update col set conf = ?, decks = ?, dconf = ?""", json.dumps({"1": g}), json.dumps({"1": gc}), ) - - -def _updateIndices(db: DBProxy) -> None: - "Add indices to the DB." - db.executescript( - """ --- syncing -create index if not exists ix_notes_usn on notes (usn); -create index if not exists ix_cards_usn on cards (usn); -create index if not exists ix_revlog_usn on revlog (usn); --- card spacing, etc -create index if not exists ix_cards_nid on cards (nid); --- scheduling and deck limiting -create index if not exists ix_cards_sched on cards (did, queue, due); --- revlog by card -create index if not exists ix_revlog_cid on revlog (cid); --- field uniqueness -create index if not exists ix_notes_csum on notes (csum); -""" - ) diff --git a/rslib/src/backend/dbproxy.rs b/rslib/src/backend/dbproxy.rs index f86be1b0b..0cec30aca 100644 --- a/rslib/src/backend/dbproxy.rs +++ b/rslib/src/backend/dbproxy.rs @@ -1,23 +1,26 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use crate::backend_proto as pb; use crate::err::Result; use crate::storage::SqliteStorage; use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; use serde_derive::{Deserialize, Serialize}; #[derive(Deserialize)] -pub(super) struct DBRequest { - sql: String, - args: Vec, +#[serde(tag = "kind", rename_all = "lowercase")] +pub(super) enum DBRequest { + Query { sql: String, args: Vec }, + Begin, + Commit, + Rollback, } -// #[derive(Serialize)] -// pub(super) struct DBResult { -// rows: Vec>, -// } -type DBResult = Vec>; +#[derive(Serialize)] +#[serde(untagged)] +pub(super) enum DBResult { + Rows(Vec>), + None, +} #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] @@ -55,30 +58,41 @@ impl FromSql for SqlValue { } } -pub(super) fn db_query_json_str(db: &SqliteStorage, input: &[u8]) -> Result { +pub(super) fn db_command_bytes(db: &SqliteStorage, input: &[u8]) -> Result { let req: DBRequest = serde_json::from_slice(input)?; - let resp = db_query_json(db, req)?; + let resp = match req { + DBRequest::Query { sql, args } => db_query(db, &sql, &args)?, + DBRequest::Begin => { + db.begin()?; + DBResult::None + } + DBRequest::Commit => { + db.commit()?; + DBResult::None + } + DBRequest::Rollback => { + db.rollback()?; + DBResult::None + } + }; Ok(serde_json::to_string(&resp)?) } -pub(super) fn db_query_json(db: &SqliteStorage, input: DBRequest) -> Result { - let mut stmt = db.db.prepare_cached(&input.sql)?; +pub(super) fn db_query(db: &SqliteStorage, sql: &str, args: &[SqlValue]) -> Result { + let mut stmt = db.db.prepare_cached(sql)?; let columns = stmt.column_count(); - let mut rows = stmt.query(&input.args)?; + let res: std::result::Result>, rusqlite::Error> = stmt + .query_map(args, |row| { + let mut orow = Vec::with_capacity(columns); + for i in 0..columns { + let v: SqlValue = row.get(i)?; + orow.push(v); + } + Ok(orow) + })? + .collect(); - let mut output_rows = vec![]; - - while let Some(row) = rows.next()? { - let mut orow = Vec::with_capacity(columns); - for i in 0..columns { - let v: SqlValue = row.get(i)?; - orow.push(v); - } - - output_rows.push(orow); - } - - Ok(output_rows) + Ok(DBResult::Rows(res?)) } diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs index bbf8b5266..7f479d7d8 100644 --- a/rslib/src/backend/mod.rs +++ b/rslib/src/backend/mod.rs @@ -1,7 +1,7 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html -use crate::backend::dbproxy::db_query_json_str; +use crate::backend::dbproxy::db_command_bytes; use crate::backend_proto::backend_input::Value; use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn}; use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind}; @@ -37,7 +37,7 @@ pub struct Backend { media_folder: PathBuf, media_db: String, progress_callback: Option, - i18n: I18n, + pub i18n: I18n, log: Logger, } @@ -493,8 +493,8 @@ impl Backend { checker.restore_trash() } - pub fn db_query(&self, input: pb::DbQueryIn) -> Result { - db_query_proto(&self.col, input) + pub fn db_command(&self, input: &[u8]) -> Result { + db_command_bytes(&self.col, input) } } diff --git a/rslib/src/storage/sqlite.rs b/rslib/src/storage/sqlite.rs index 4f6a13bd1..e0d68c640 100644 --- a/rslib/src/storage/sqlite.rs +++ b/rslib/src/storage/sqlite.rs @@ -4,28 +4,12 @@ use crate::err::Result; use crate::err::{AnkiError, DBErrorKind}; use crate::time::i64_unix_timestamp; -use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef}; -use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS}; -use serde::de::DeserializeOwned; -use serde_derive::{Deserialize, Serialize}; -use serde_json::Value; -use std::borrow::Cow; -use std::convert::TryFrom; -use std::fmt; +use rusqlite::{params, Connection, NO_PARAMS}; use std::path::{Path, PathBuf}; const SCHEMA_MIN_VERSION: u8 = 11; const SCHEMA_MAX_VERSION: u8 = 11; -macro_rules! cached_sql { - ( $label:expr, $db:expr, $sql:expr ) => {{ - if $label.is_none() { - $label = Some($db.prepare_cached($sql)?); - } - $label.as_mut().unwrap() - }}; -} - // currently public for dbproxy #[derive(Debug)] pub struct SqliteStorage { @@ -42,6 +26,8 @@ fn open_or_create_collection_db(path: &Path) -> Result { db.trace(Some(trace)); } + db.busy_timeout(std::time::Duration::from_secs(0))?; + db.pragma_update(None, "locking_mode", &"exclusive")?; db.pragma_update(None, "page_size", &4096)?; db.pragma_update(None, "cache_size", &(-40 * 1024))?; @@ -78,7 +64,6 @@ impl SqliteStorage { let (create, ver) = schema_version(&db)?; if create { - unimplemented!(); // todo db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?; db.execute_batch(include_str!("schema11.sql"))?; db.execute( @@ -118,12 +103,16 @@ impl SqliteStorage { } pub(crate) fn commit(&self) -> Result<()> { - self.db.prepare_cached("commit")?.execute(NO_PARAMS)?; + if !self.db.is_autocommit() { + self.db.prepare_cached("commit")?.execute(NO_PARAMS)?; + } Ok(()) } pub(crate) fn rollback(&self) -> Result<()> { - self.db.execute("rollback", NO_PARAMS)?; + if !self.db.is_autocommit() { + self.db.execute("rollback", NO_PARAMS)?; + } Ok(()) } } diff --git a/rspy/src/lib.rs b/rspy/src/lib.rs index d921aa799..7baad5569 100644 --- a/rspy/src/lib.rs +++ b/rspy/src/lib.rs @@ -4,9 +4,10 @@ use anki::backend::{ init_backend, init_i18n_backend, Backend as RustBackend, I18nBackend as RustI18nBackend, }; +use pyo3::exceptions::Exception; use pyo3::prelude::*; use pyo3::types::PyBytes; -use pyo3::{exceptions, wrap_pyfunction}; +use pyo3::{create_exception, exceptions, wrap_pyfunction}; // Regular backend ////////////////////////////////// @@ -16,6 +17,8 @@ struct Backend { backend: RustBackend, } +create_exception!(ankirspy, DBError, Exception); + #[pyfunction] fn buildhash() -> &'static str { include_str!("../../meta/buildhash").trim() @@ -71,11 +74,15 @@ impl Backend { } } - fn db_query(&mut self, py: Python, input: &PyBytes) -> PyObject { + fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult { let in_bytes = input.as_bytes(); - let out_string = self.backend.db_query(in_bytes).unwrap(); + let out_string = self + .backend + .db_command(in_bytes) + .map_err(|e| DBError::py_err(e.localized_description(&self.backend.i18n)))?; + let out_obj = PyBytes::new(py, out_string.as_bytes()); - out_obj.into() + Ok(out_obj.into()) } }