add begin/commit/rollback, and support creating collections

all but one unit test is now passing
Damien Elmes 2020-03-05 10:54:30 +10:00
parent 6db4418f05
commit 2cd7885ec0
8 changed files with 109 additions and 211 deletions
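
In rough outline, the new transaction API is driven from Python like this — a sketch only, assuming an open collection `col`; the names come from the diffs below, the control flow is illustrative rather than part of the commit:

    col.db.begin()                  # DBProxy -> RustBackend.db_begin()
    try:
        col.db.execute("update col set mod = mod")
        col.db.commit()             # -> RustBackend.db_commit()
    except Exception:
        col.db.rollback()           # -> RustBackend.db_rollback()
        raise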

pylib/anki/collection.py

@@ -242,10 +242,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
         return None
 
     def lock(self) -> None:
-        # make sure we don't accidentally bump mod time
-        mod = self.db.mod
-        self.db.execute("update col set mod=mod")
-        self.db.mod = mod
+        self.db.begin()
 
     def close(self, save: bool = True) -> None:
         "Disconnect from DB."
@@ -260,6 +257,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
             self.db.setAutocommit(False)
             self.db.close()
             self.db = None
+            self.backend = None
             self.media.close()
             self._closeLog()

pylib/anki/dbproxy.py

@@ -3,11 +3,14 @@
 from __future__ import annotations
 
+import weakref
 from typing import Any, Iterable, List, Optional, Sequence, Union
 
 import anki
 
-# fixme: remember to null on close to avoid circular ref
+# fixme: col.reopen()
+# fixme: setAutocommit()
+# fixme: transaction/lock handling
 # fixme: progress
 
 # DBValue is actually Union[str, int, float, None], but if defined
@@ -24,7 +27,7 @@ class DBProxy:
     ###############
 
     def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None:
-        self._backend = backend
+        self._backend = weakref.proxy(backend)
         self._path = path
         self.mod = False
 
@@ -35,17 +38,20 @@ class DBProxy:
     # Transactions
     ###############
 
+    def begin(self) -> None:
+        self._backend.db_begin()
+
     def commit(self) -> None:
-        # fixme
-        pass
+        self._backend.db_commit()
 
     def rollback(self) -> None:
-        # fixme
-        pass
+        self._backend.db_rollback()
 
     def setAutocommit(self, autocommit: bool) -> None:
-        # fixme
-        pass
+        if autocommit:
+            self.commit()
+        else:
+            self.begin()
 
     # Querying
     ################
@@ -58,6 +64,7 @@ class DBProxy:
         for stmt in "insert", "update", "delete":
             if s.startswith(stmt):
                 self.mod = True
+        assert ":" not in sql
         # fetch rows
         # fixme: first_row_only
         return self._backend.db_query(sql, args)
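
Switching to weakref.proxy() is what retires the old "remember to null on close to avoid circular ref" fixme: the proxy never extends the backend's lifetime, so no manual nulling is required. A self-contained illustration of that standard-library behaviour (the Backend class here is a stand-in, not Anki's):

    import weakref

    class Backend:
        pass

    b = Backend()
    p = weakref.proxy(b)
    del b                 # last strong reference gone
    try:
        repr(p)           # any use of the proxy now raises
    except ReferenceError:
        print("proxy did not keep the backend alive")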

pylib/anki/rsbackend.py

@@ -6,6 +6,7 @@ import enum
 import os
 from dataclasses import dataclass
 from typing import (
+    Any,
     Callable,
     Dict,
     Iterable,
@@ -15,7 +16,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
-    Any)
+)
 
 import ankirspy  # pytype: disable=import-error
 import orjson
@@ -386,9 +387,20 @@ class RustBackend:
         self._run_command(pb.BackendInput(restore_trash=pb.Empty()))
 
     def db_query(self, sql: str, args: Iterable[ValueForDB]) -> List[DBRow]:
-        input = orjson.dumps(dict(sql=sql, args=args))
-        output = self._backend.db_query(input)
-        return orjson.loads(output)
+        return self._db_command(dict(kind="query", sql=sql, args=args))
+
+    def db_begin(self) -> None:
+        return self._db_command(dict(kind="begin"))
+
+    def db_commit(self) -> None:
+        return self._db_command(dict(kind="commit"))
+
+    def db_rollback(self) -> None:
+        return self._db_command(dict(kind="rollback"))
+
+    def _db_command(self, input: Dict[str, Any]) -> Any:
+        return orjson.loads(self._backend.db_command(orjson.dumps(input)))
 
 def translate_string_in(
     key: TR, **kwargs: Union[str, int, float]
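
Each wrapper serializes a single JSON object whose "kind" field selects the command; this is the wire format the tagged DBRequest enum in rslib's dbproxy.rs (below) deserializes. A runnable sketch of the exact payloads, using orjson as the diff does:

    import orjson

    print(orjson.dumps({"kind": "begin"}))
    # b'{"kind":"begin"}'
    print(orjson.dumps({"kind": "query", "sql": "select ver from col", "args": []}))
    # b'{"kind":"query","sql":"select ver from col","args":[]}'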

pylib/anki/storage.py

@@ -26,9 +26,7 @@ class ServerData:
     minutes_west: Optional[int] = None
 
 
-def Collection(
-    path: str, lock: bool = True, server: Optional[ServerData] = None, log: bool = False
-) -> _Collection:
+def Collection(path: str, server: Optional[ServerData] = None) -> _Collection:
     "Open a new or existing collection. Path must be unicode."
     assert path.endswith(".anki2")
     (media_dir, media_db) = media_paths_from_col_path(path)
@@ -36,33 +34,23 @@ def Collection(
     if not server:
         log_path = path.replace(".anki2", "2.log")
     path = os.path.abspath(path)
-    create = not os.path.exists(path)
-    if create:
-        base = os.path.basename(path)
-        for c in ("/", ":", "\\"):
-            assert c not in base
     # connect
     backend = RustBackend(
         path, media_dir, media_db, log_path, server=server is not None
     )
     db = DBProxy(backend, path)
+    db.begin()
     db.setAutocommit(True)
+
+    # initial setup required?
+    create = db.scalar("select models = '{}' from col")
     if create:
-        ver = _createDB(db)
-    else:
-        ver = _upgradeSchema(db)
-    db.execute("pragma temp_store = memory")
-    db.execute("pragma cache_size = 10000")
-    if not isWin:
-        db.execute("pragma journal_mode = wal")
+        initial_db_setup(db)
+
     db.setAutocommit(False)
     # add db to col and do any remaining upgrades
-    col = _Collection(db, backend=backend, server=server, log=log)
-    if ver < SCHEMA_VERSION:
-        raise Exception("This file requires an older version of Anki.")
-    elif ver > SCHEMA_VERSION:
-        raise Exception("This file requires a newer version of Anki.")
-    elif create:
+    col = _Collection(db, backend=backend, server=server)
+    if create:
         # add in reverse order so basic is default
         addClozeModel(col)
         addBasicTypingModel(col)
@@ -70,111 +58,14 @@ def Collection(
         addForwardReverse(col)
         addBasicModel(col)
         col.save()
-    if lock:
-        try:
-            col.lock()
-        except:
-            col.db.close()
-            raise
     return col
 
 
-def _upgradeSchema(db: DBProxy) -> Any:
-    return db.scalar("select ver from col")
-
-
 # Creating a new collection
 ######################################################################
 
 
-def _createDB(db: DBProxy) -> int:
-    db.execute("pragma page_size = 4096")
-    db.execute("pragma legacy_file_format = 0")
-    db.execute("vacuum")
-    _addSchema(db)
-    _updateIndices(db)
-    db.execute("analyze")
-    return SCHEMA_VERSION
-
-
-def _addSchema(db: DBProxy, setColConf: bool = True) -> None:
-    db.executescript(
-        """
-create table if not exists col (
-    id integer primary key,
-    crt integer not null,
-    mod integer not null,
-    scm integer not null,
-    ver integer not null,
-    dty integer not null,
-    usn integer not null,
-    ls integer not null,
-    conf text not null,
-    models text not null,
-    decks text not null,
-    dconf text not null,
-    tags text not null
-);
-
-create table if not exists notes (
-    id integer primary key, /* 0 */
-    guid text not null, /* 1 */
-    mid integer not null, /* 2 */
-    mod integer not null, /* 3 */
-    usn integer not null, /* 4 */
-    tags text not null, /* 5 */
-    flds text not null, /* 6 */
-    sfld integer not null, /* 7 */
-    csum integer not null, /* 8 */
-    flags integer not null, /* 9 */
-    data text not null /* 10 */
-);
-
-create table if not exists cards (
-    id integer primary key, /* 0 */
-    nid integer not null, /* 1 */
-    did integer not null, /* 2 */
-    ord integer not null, /* 3 */
-    mod integer not null, /* 4 */
-    usn integer not null, /* 5 */
-    type integer not null, /* 6 */
-    queue integer not null, /* 7 */
-    due integer not null, /* 8 */
-    ivl integer not null, /* 9 */
-    factor integer not null, /* 10 */
-    reps integer not null, /* 11 */
-    lapses integer not null, /* 12 */
-    left integer not null, /* 13 */
-    odue integer not null, /* 14 */
-    odid integer not null, /* 15 */
-    flags integer not null, /* 16 */
-    data text not null /* 17 */
-);
-
-create table if not exists revlog (
-    id integer primary key,
-    cid integer not null,
-    usn integer not null,
-    ease integer not null,
-    ivl integer not null,
-    lastIvl integer not null,
-    factor integer not null,
-    time integer not null,
-    type integer not null
-);
-
-create table if not exists graves (
-    usn integer not null,
-    oid integer not null,
-    type integer not null
-);
-
-insert or ignore into col
-values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
-"""
-        % ({"v": SCHEMA_VERSION, "s": intTime(1000)})
-    )
-    if setColConf:
-        _addColVars(db, *_getColVars(db))
+def initial_db_setup(db: DBProxy) -> None:
+    _addColVars(db, *_getColVars(db))
@@ -202,23 +93,3 @@ update col set conf = ?, decks = ?, dconf = ?""",
         json.dumps({"1": g}),
         json.dumps({"1": gc}),
     )
-
-
-def _updateIndices(db: DBProxy) -> None:
-    "Add indices to the DB."
-    db.executescript(
-        """
--- syncing
-create index if not exists ix_notes_usn on notes (usn);
-create index if not exists ix_cards_usn on cards (usn);
-create index if not exists ix_revlog_usn on revlog (usn);
--- card spacing, etc
-create index if not exists ix_cards_nid on cards (nid);
--- scheduling and deck limiting
-create index if not exists ix_cards_sched on cards (did, queue, due);
--- revlog by card
-create index if not exists ix_revlog_cid on revlog (cid);
--- field uniqueness
-create index if not exists ix_notes_csum on notes (csum);
-"""
-    )
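
Schema creation itself now lives on the Rust side (the schema11.sql bootstrap in sqlite.rs below), so the Python factory only needs to detect a freshly created file. The removed _addSchema inserted the literal '{}' for col.models, and the detection query assumes the Rust bootstrap does the same. A standalone illustration with a one-column stand-in table rather than the real schema:

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.execute("create table col (models text not null)")
    db.execute("insert into col values ('{}')")   # what a fresh bootstrap writes
    create = db.execute("select models = '{}' from col").fetchone()[0]
    print(bool(create))   # True -> run initial_db_setup() to seed col vars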

rslib/src/backend/dbproxy.rs

@@ -1,23 +1,26 @@
 // Copyright: Ankitects Pty Ltd and contributors
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
-use crate::backend_proto as pb;
 use crate::err::Result;
 use crate::storage::SqliteStorage;
 use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
 use serde_derive::{Deserialize, Serialize};
 
 #[derive(Deserialize)]
-pub(super) struct DBRequest {
-    sql: String,
-    args: Vec<SqlValue>,
+#[serde(tag = "kind", rename_all = "lowercase")]
+pub(super) enum DBRequest {
+    Query { sql: String, args: Vec<SqlValue> },
+    Begin,
+    Commit,
+    Rollback,
 }
 
-// #[derive(Serialize)]
-// pub(super) struct DBResult {
-//     rows: Vec<Vec<SqlValue>>,
-// }
-type DBResult = Vec<Vec<SqlValue>>;
+#[derive(Serialize)]
+#[serde(untagged)]
+pub(super) enum DBResult {
+    Rows(Vec<Vec<SqlValue>>),
+    None,
+}
 
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(untagged)]
@@ -55,30 +58,41 @@ impl FromSql for SqlValue {
     }
 }
 
-pub(super) fn db_query_json_str(db: &SqliteStorage, input: &[u8]) -> Result<String> {
+pub(super) fn db_command_bytes(db: &SqliteStorage, input: &[u8]) -> Result<String> {
     let req: DBRequest = serde_json::from_slice(input)?;
-    let resp = db_query_json(db, req)?;
+    let resp = match req {
+        DBRequest::Query { sql, args } => db_query(db, &sql, &args)?,
+        DBRequest::Begin => {
+            db.begin()?;
+            DBResult::None
+        }
+        DBRequest::Commit => {
+            db.commit()?;
+            DBResult::None
+        }
+        DBRequest::Rollback => {
+            db.rollback()?;
+            DBResult::None
+        }
+    };
     Ok(serde_json::to_string(&resp)?)
 }
 
-pub(super) fn db_query_json(db: &SqliteStorage, input: DBRequest) -> Result<DBResult> {
-    let mut stmt = db.db.prepare_cached(&input.sql)?;
+pub(super) fn db_query(db: &SqliteStorage, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
+    let mut stmt = db.db.prepare_cached(sql)?;
     let columns = stmt.column_count();
 
-    let mut rows = stmt.query(&input.args)?;
-
-    let mut output_rows = vec![];
-
-    while let Some(row) = rows.next()? {
+    let res: std::result::Result<Vec<Vec<_>>, rusqlite::Error> = stmt
+        .query_map(args, |row| {
             let mut orow = Vec::with_capacity(columns);
             for i in 0..columns {
                 let v: SqlValue = row.get(i)?;
                 orow.push(v);
             }
-        output_rows.push(orow);
-    }
-
-    Ok(output_rows)
+            Ok(orow)
+        })?
+        .collect();
+
+    Ok(DBResult::Rows(res?))
 }
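
The serde attributes define the wire format: tag = "kind" with rename_all = "lowercase" makes each request an object whose "kind" field names the variant, while the untagged DBResult serializes Rows as a bare nested array and None as JSON null. Shown concretely from the Python side (plain json here for readability; the real code uses orjson and serde_json):

    import json

    requests = [
        {"kind": "begin"},
        {"kind": "query", "sql": "select id from cards", "args": []},
    ]
    rows = [[1, "a", 2.5]]   # DBResult::Rows -> bare nested array
    none = None              # DBResult::None -> null
    print(json.dumps(requests), json.dumps(rows), json.dumps(none))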

rslib/src/backend/mod.rs

@@ -1,7 +1,7 @@
 // Copyright: Ankitects Pty Ltd and contributors
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
-use crate::backend::dbproxy::db_query_json_str;
+use crate::backend::dbproxy::db_command_bytes;
 use crate::backend_proto::backend_input::Value;
 use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
 use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
@@ -37,7 +37,7 @@ pub struct Backend {
     media_folder: PathBuf,
     media_db: String,
     progress_callback: Option<ProtoProgressCallback>,
-    i18n: I18n,
+    pub i18n: I18n,
     log: Logger,
 }
@@ -493,8 +493,8 @@ impl Backend {
         checker.restore_trash()
     }
 
-    pub fn db_query(&self, input: pb::DbQueryIn) -> Result<pb::DbQueryOut> {
-        db_query_proto(&self.col, input)
+    pub fn db_command(&self, input: &[u8]) -> Result<String> {
+        db_command_bytes(&self.col, input)
     }
 }

rslib/src/storage/sqlite.rs

@@ -4,28 +4,12 @@
 use crate::err::Result;
 use crate::err::{AnkiError, DBErrorKind};
 use crate::time::i64_unix_timestamp;
-use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef};
-use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS};
-use serde::de::DeserializeOwned;
-use serde_derive::{Deserialize, Serialize};
-use serde_json::Value;
-use std::borrow::Cow;
-use std::convert::TryFrom;
-use std::fmt;
+use rusqlite::{params, Connection, NO_PARAMS};
 use std::path::{Path, PathBuf};
 
 const SCHEMA_MIN_VERSION: u8 = 11;
 const SCHEMA_MAX_VERSION: u8 = 11;
 
-macro_rules! cached_sql {
-    ( $label:expr, $db:expr, $sql:expr ) => {{
-        if $label.is_none() {
-            $label = Some($db.prepare_cached($sql)?);
-        }
-        $label.as_mut().unwrap()
-    }};
-}
 
 // currently public for dbproxy
 #[derive(Debug)]
 pub struct SqliteStorage {
@@ -42,6 +26,8 @@ fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
         db.trace(Some(trace));
     }
 
+    db.busy_timeout(std::time::Duration::from_secs(0))?;
+
     db.pragma_update(None, "locking_mode", &"exclusive")?;
     db.pragma_update(None, "page_size", &4096)?;
     db.pragma_update(None, "cache_size", &(-40 * 1024))?;
@@ -78,7 +64,6 @@ impl SqliteStorage {
         let (create, ver) = schema_version(&db)?;
         if create {
-            unimplemented!(); // todo
             db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?;
             db.execute_batch(include_str!("schema11.sql"))?;
             db.execute(
@@ -118,12 +103,16 @@ impl SqliteStorage {
     }
 
     pub(crate) fn commit(&self) -> Result<()> {
+        if !self.db.is_autocommit() {
             self.db.prepare_cached("commit")?.execute(NO_PARAMS)?;
+        }
         Ok(())
     }
 
     pub(crate) fn rollback(&self) -> Result<()> {
+        if !self.db.is_autocommit() {
             self.db.execute("rollback", NO_PARAMS)?;
+        }
         Ok(())
     }
 }
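
The is_autocommit() guards make commit() and rollback() safe no-ops when no transaction is open, which matters now that the Python side may call them unconditionally (e.g. setAutocommit(True) with no preceding begin()). The same idea expressed with Python's sqlite3 module, as an analogy rather than the backend code:

    import sqlite3

    db = sqlite3.connect(":memory:", isolation_level=None)  # manual transactions

    def commit(db: sqlite3.Connection) -> None:
        if db.in_transaction:          # mirrors !self.db.is_autocommit()
            db.execute("commit")

    commit(db)                         # safe no-op outside a transaction
    db.execute("begin")
    db.execute("create table t (x)")
    commit(db)                         # commits the open transaction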

rspy/src/lib.rs

@@ -4,9 +4,10 @@
 use anki::backend::{
     init_backend, init_i18n_backend, Backend as RustBackend, I18nBackend as RustI18nBackend,
 };
+use pyo3::exceptions::Exception;
 use pyo3::prelude::*;
 use pyo3::types::PyBytes;
-use pyo3::{exceptions, wrap_pyfunction};
+use pyo3::{create_exception, exceptions, wrap_pyfunction};
 
 // Regular backend
 //////////////////////////////////
@@ -16,6 +17,8 @@ struct Backend {
     backend: RustBackend,
 }
 
+create_exception!(ankirspy, DBError, Exception);
+
 #[pyfunction]
 fn buildhash() -> &'static str {
     include_str!("../../meta/buildhash").trim()
@@ -71,11 +74,15 @@ impl Backend {
         }
     }
 
-    fn db_query(&mut self, py: Python, input: &PyBytes) -> PyObject {
+    fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult<PyObject> {
         let in_bytes = input.as_bytes();
-        let out_string = self.backend.db_query(in_bytes).unwrap();
+        let out_string = self
+            .backend
+            .db_command(in_bytes)
+            .map_err(|e| DBError::py_err(e.localized_description(&self.backend.i18n)))?;
 
         let out_obj = PyBytes::new(py, out_string.as_bytes());
-        out_obj.into()
+        Ok(out_obj.into())
     }
 }
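
With db_command returning PyResult, a backend failure now raises the new DBError on the Python side instead of panicking in unwrap(). A hedged usage sketch; it assumes DBError is exported from ankirspy and that backend is an open ankirspy backend, neither of which is shown in this diff:

    import orjson
    from ankirspy import DBError  # assumption: exposed by the module

    try:
        backend.db_command(orjson.dumps({"kind": "query", "sql": "select", "args": []}))
    except DBError as e:
        print("localized backend error:", e)  # invalid SQL surfaces here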