add begin/commit/rollback, and support creating collections
all but one unit test is now passing
This commit is contained in:
parent
6db4418f05
commit
2cd7885ec0
@ -242,10 +242,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
|
||||
return None
|
||||
|
||||
def lock(self) -> None:
|
||||
# make sure we don't accidentally bump mod time
|
||||
mod = self.db.mod
|
||||
self.db.execute("update col set mod=mod")
|
||||
self.db.mod = mod
|
||||
self.db.begin()
|
||||
|
||||
def close(self, save: bool = True) -> None:
|
||||
"Disconnect from DB."
|
||||
@ -260,6 +257,7 @@ crt=?, mod=?, scm=?, dty=?, usn=?, ls=?, conf=?""",
|
||||
self.db.setAutocommit(False)
|
||||
self.db.close()
|
||||
self.db = None
|
||||
self.backend = None
|
||||
self.media.close()
|
||||
self._closeLog()
|
||||
|
||||
|
@ -3,11 +3,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
import anki
|
||||
|
||||
# fixme: remember to null on close to avoid circular ref
|
||||
# fixme: col.reopen()
|
||||
# fixme: setAutocommit()
|
||||
# fixme: transaction/lock handling
|
||||
# fixme: progress
|
||||
|
||||
# DBValue is actually Union[str, int, float, None], but if defined
|
||||
@ -24,7 +27,7 @@ class DBProxy:
|
||||
###############
|
||||
|
||||
def __init__(self, backend: anki.rsbackend.RustBackend, path: str) -> None:
|
||||
self._backend = backend
|
||||
self._backend = weakref.proxy(backend)
|
||||
self._path = path
|
||||
self.mod = False
|
||||
|
||||
@ -35,17 +38,20 @@ class DBProxy:
|
||||
# Transactions
|
||||
###############
|
||||
|
||||
def begin(self) -> None:
|
||||
self._backend.db_begin()
|
||||
|
||||
def commit(self) -> None:
|
||||
# fixme
|
||||
pass
|
||||
self._backend.db_commit()
|
||||
|
||||
def rollback(self) -> None:
|
||||
# fixme
|
||||
pass
|
||||
self._backend.db_rollback()
|
||||
|
||||
def setAutocommit(self, autocommit: bool) -> None:
|
||||
# fixme
|
||||
pass
|
||||
if autocommit:
|
||||
self.commit()
|
||||
else:
|
||||
self.begin()
|
||||
|
||||
# Querying
|
||||
################
|
||||
@ -58,6 +64,7 @@ class DBProxy:
|
||||
for stmt in "insert", "update", "delete":
|
||||
if s.startswith(stmt):
|
||||
self.mod = True
|
||||
assert ":" not in sql
|
||||
# fetch rows
|
||||
# fixme: first_row_only
|
||||
return self._backend.db_query(sql, args)
|
||||
|
@ -6,6 +6,7 @@ import enum
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
@ -15,7 +16,7 @@ from typing import (
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
Any)
|
||||
)
|
||||
|
||||
import ankirspy # pytype: disable=import-error
|
||||
import orjson
|
||||
@ -386,9 +387,20 @@ class RustBackend:
|
||||
self._run_command(pb.BackendInput(restore_trash=pb.Empty()))
|
||||
|
||||
def db_query(self, sql: str, args: Iterable[ValueForDB]) -> List[DBRow]:
|
||||
input = orjson.dumps(dict(sql=sql, args=args))
|
||||
output = self._backend.db_query(input)
|
||||
return orjson.loads(output)
|
||||
return self._db_command(dict(kind="query", sql=sql, args=args))
|
||||
|
||||
def db_begin(self) -> None:
|
||||
return self._db_command(dict(kind="begin"))
|
||||
|
||||
def db_commit(self) -> None:
|
||||
return self._db_command(dict(kind="commit"))
|
||||
|
||||
def db_rollback(self) -> None:
|
||||
return self._db_command(dict(kind="rollback"))
|
||||
|
||||
def _db_command(self, input: Dict[str, Any]) -> Any:
|
||||
return orjson.loads(self._backend.db_command(orjson.dumps(input)))
|
||||
|
||||
|
||||
def translate_string_in(
|
||||
key: TR, **kwargs: Union[str, int, float]
|
||||
|
@ -26,9 +26,7 @@ class ServerData:
|
||||
minutes_west: Optional[int] = None
|
||||
|
||||
|
||||
def Collection(
|
||||
path: str, lock: bool = True, server: Optional[ServerData] = None, log: bool = False
|
||||
) -> _Collection:
|
||||
def Collection(path: str, server: Optional[ServerData] = None) -> _Collection:
|
||||
"Open a new or existing collection. Path must be unicode."
|
||||
assert path.endswith(".anki2")
|
||||
(media_dir, media_db) = media_paths_from_col_path(path)
|
||||
@ -36,33 +34,23 @@ def Collection(
|
||||
if not server:
|
||||
log_path = path.replace(".anki2", "2.log")
|
||||
path = os.path.abspath(path)
|
||||
create = not os.path.exists(path)
|
||||
if create:
|
||||
base = os.path.basename(path)
|
||||
for c in ("/", ":", "\\"):
|
||||
assert c not in base
|
||||
# connect
|
||||
backend = RustBackend(
|
||||
path, media_dir, media_db, log_path, server=server is not None
|
||||
)
|
||||
db = DBProxy(backend, path)
|
||||
db.begin()
|
||||
db.setAutocommit(True)
|
||||
|
||||
# initial setup required?
|
||||
create = db.scalar("select models = '{}' from col")
|
||||
if create:
|
||||
ver = _createDB(db)
|
||||
else:
|
||||
ver = _upgradeSchema(db)
|
||||
db.execute("pragma temp_store = memory")
|
||||
db.execute("pragma cache_size = 10000")
|
||||
if not isWin:
|
||||
db.execute("pragma journal_mode = wal")
|
||||
initial_db_setup(db)
|
||||
|
||||
db.setAutocommit(False)
|
||||
# add db to col and do any remaining upgrades
|
||||
col = _Collection(db, backend=backend, server=server, log=log)
|
||||
if ver < SCHEMA_VERSION:
|
||||
raise Exception("This file requires an older version of Anki.")
|
||||
elif ver > SCHEMA_VERSION:
|
||||
raise Exception("This file requires a newer version of Anki.")
|
||||
elif create:
|
||||
col = _Collection(db, backend=backend, server=server)
|
||||
if create:
|
||||
# add in reverse order so basic is default
|
||||
addClozeModel(col)
|
||||
addBasicTypingModel(col)
|
||||
@ -70,111 +58,14 @@ def Collection(
|
||||
addForwardReverse(col)
|
||||
addBasicModel(col)
|
||||
col.save()
|
||||
if lock:
|
||||
try:
|
||||
col.lock()
|
||||
except:
|
||||
col.db.close()
|
||||
raise
|
||||
return col
|
||||
|
||||
|
||||
def _upgradeSchema(db: DBProxy) -> Any:
|
||||
return db.scalar("select ver from col")
|
||||
|
||||
|
||||
# Creating a new collection
|
||||
######################################################################
|
||||
|
||||
|
||||
def _createDB(db: DBProxy) -> int:
|
||||
db.execute("pragma page_size = 4096")
|
||||
db.execute("pragma legacy_file_format = 0")
|
||||
db.execute("vacuum")
|
||||
_addSchema(db)
|
||||
_updateIndices(db)
|
||||
db.execute("analyze")
|
||||
return SCHEMA_VERSION
|
||||
|
||||
|
||||
def _addSchema(db: DBProxy, setColConf: bool = True) -> None:
|
||||
db.executescript(
|
||||
"""
|
||||
create table if not exists col (
|
||||
id integer primary key,
|
||||
crt integer not null,
|
||||
mod integer not null,
|
||||
scm integer not null,
|
||||
ver integer not null,
|
||||
dty integer not null,
|
||||
usn integer not null,
|
||||
ls integer not null,
|
||||
conf text not null,
|
||||
models text not null,
|
||||
decks text not null,
|
||||
dconf text not null,
|
||||
tags text not null
|
||||
);
|
||||
|
||||
create table if not exists notes (
|
||||
id integer primary key, /* 0 */
|
||||
guid text not null, /* 1 */
|
||||
mid integer not null, /* 2 */
|
||||
mod integer not null, /* 3 */
|
||||
usn integer not null, /* 4 */
|
||||
tags text not null, /* 5 */
|
||||
flds text not null, /* 6 */
|
||||
sfld integer not null, /* 7 */
|
||||
csum integer not null, /* 8 */
|
||||
flags integer not null, /* 9 */
|
||||
data text not null /* 10 */
|
||||
);
|
||||
|
||||
create table if not exists cards (
|
||||
id integer primary key, /* 0 */
|
||||
nid integer not null, /* 1 */
|
||||
did integer not null, /* 2 */
|
||||
ord integer not null, /* 3 */
|
||||
mod integer not null, /* 4 */
|
||||
usn integer not null, /* 5 */
|
||||
type integer not null, /* 6 */
|
||||
queue integer not null, /* 7 */
|
||||
due integer not null, /* 8 */
|
||||
ivl integer not null, /* 9 */
|
||||
factor integer not null, /* 10 */
|
||||
reps integer not null, /* 11 */
|
||||
lapses integer not null, /* 12 */
|
||||
left integer not null, /* 13 */
|
||||
odue integer not null, /* 14 */
|
||||
odid integer not null, /* 15 */
|
||||
flags integer not null, /* 16 */
|
||||
data text not null /* 17 */
|
||||
);
|
||||
|
||||
create table if not exists revlog (
|
||||
id integer primary key,
|
||||
cid integer not null,
|
||||
usn integer not null,
|
||||
ease integer not null,
|
||||
ivl integer not null,
|
||||
lastIvl integer not null,
|
||||
factor integer not null,
|
||||
time integer not null,
|
||||
type integer not null
|
||||
);
|
||||
|
||||
create table if not exists graves (
|
||||
usn integer not null,
|
||||
oid integer not null,
|
||||
type integer not null
|
||||
);
|
||||
|
||||
insert or ignore into col
|
||||
values(1,0,0,%(s)s,%(v)s,0,0,0,'','{}','','','{}');
|
||||
"""
|
||||
% ({"v": SCHEMA_VERSION, "s": intTime(1000)})
|
||||
)
|
||||
if setColConf:
|
||||
def initial_db_setup(db: DBProxy) -> None:
|
||||
_addColVars(db, *_getColVars(db))
|
||||
|
||||
|
||||
@ -202,23 +93,3 @@ update col set conf = ?, decks = ?, dconf = ?""",
|
||||
json.dumps({"1": g}),
|
||||
json.dumps({"1": gc}),
|
||||
)
|
||||
|
||||
|
||||
def _updateIndices(db: DBProxy) -> None:
|
||||
"Add indices to the DB."
|
||||
db.executescript(
|
||||
"""
|
||||
-- syncing
|
||||
create index if not exists ix_notes_usn on notes (usn);
|
||||
create index if not exists ix_cards_usn on cards (usn);
|
||||
create index if not exists ix_revlog_usn on revlog (usn);
|
||||
-- card spacing, etc
|
||||
create index if not exists ix_cards_nid on cards (nid);
|
||||
-- scheduling and deck limiting
|
||||
create index if not exists ix_cards_sched on cards (did, queue, due);
|
||||
-- revlog by card
|
||||
create index if not exists ix_revlog_cid on revlog (cid);
|
||||
-- field uniqueness
|
||||
create index if not exists ix_notes_csum on notes (csum);
|
||||
"""
|
||||
)
|
||||
|
@ -1,23 +1,26 @@
|
||||
// Copyright: Ankitects Pty Ltd and contributors
|
||||
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||
|
||||
use crate::backend_proto as pb;
|
||||
use crate::err::Result;
|
||||
use crate::storage::SqliteStorage;
|
||||
use rusqlite::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct DBRequest {
|
||||
sql: String,
|
||||
args: Vec<SqlValue>,
|
||||
#[serde(tag = "kind", rename_all = "lowercase")]
|
||||
pub(super) enum DBRequest {
|
||||
Query { sql: String, args: Vec<SqlValue> },
|
||||
Begin,
|
||||
Commit,
|
||||
Rollback,
|
||||
}
|
||||
|
||||
// #[derive(Serialize)]
|
||||
// pub(super) struct DBResult {
|
||||
// rows: Vec<Vec<SqlValue>>,
|
||||
// }
|
||||
type DBResult = Vec<Vec<SqlValue>>;
|
||||
#[derive(Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub(super) enum DBResult {
|
||||
Rows(Vec<Vec<SqlValue>>),
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
@ -55,30 +58,41 @@ impl FromSql for SqlValue {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn db_query_json_str(db: &SqliteStorage, input: &[u8]) -> Result<String> {
|
||||
pub(super) fn db_command_bytes(db: &SqliteStorage, input: &[u8]) -> Result<String> {
|
||||
let req: DBRequest = serde_json::from_slice(input)?;
|
||||
let resp = db_query_json(db, req)?;
|
||||
let resp = match req {
|
||||
DBRequest::Query { sql, args } => db_query(db, &sql, &args)?,
|
||||
DBRequest::Begin => {
|
||||
db.begin()?;
|
||||
DBResult::None
|
||||
}
|
||||
DBRequest::Commit => {
|
||||
db.commit()?;
|
||||
DBResult::None
|
||||
}
|
||||
DBRequest::Rollback => {
|
||||
db.rollback()?;
|
||||
DBResult::None
|
||||
}
|
||||
};
|
||||
Ok(serde_json::to_string(&resp)?)
|
||||
}
|
||||
|
||||
pub(super) fn db_query_json(db: &SqliteStorage, input: DBRequest) -> Result<DBResult> {
|
||||
let mut stmt = db.db.prepare_cached(&input.sql)?;
|
||||
pub(super) fn db_query(db: &SqliteStorage, sql: &str, args: &[SqlValue]) -> Result<DBResult> {
|
||||
let mut stmt = db.db.prepare_cached(sql)?;
|
||||
|
||||
let columns = stmt.column_count();
|
||||
|
||||
let mut rows = stmt.query(&input.args)?;
|
||||
|
||||
let mut output_rows = vec![];
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let res: std::result::Result<Vec<Vec<_>>, rusqlite::Error> = stmt
|
||||
.query_map(args, |row| {
|
||||
let mut orow = Vec::with_capacity(columns);
|
||||
for i in 0..columns {
|
||||
let v: SqlValue = row.get(i)?;
|
||||
orow.push(v);
|
||||
}
|
||||
Ok(orow)
|
||||
})?
|
||||
.collect();
|
||||
|
||||
output_rows.push(orow);
|
||||
}
|
||||
|
||||
Ok(output_rows)
|
||||
Ok(DBResult::Rows(res?))
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright: Ankitects Pty Ltd and contributors
|
||||
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||
|
||||
use crate::backend::dbproxy::db_query_json_str;
|
||||
use crate::backend::dbproxy::db_command_bytes;
|
||||
use crate::backend_proto::backend_input::Value;
|
||||
use crate::backend_proto::{Empty, RenderedTemplateReplacement, SyncMediaIn};
|
||||
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
|
||||
@ -37,7 +37,7 @@ pub struct Backend {
|
||||
media_folder: PathBuf,
|
||||
media_db: String,
|
||||
progress_callback: Option<ProtoProgressCallback>,
|
||||
i18n: I18n,
|
||||
pub i18n: I18n,
|
||||
log: Logger,
|
||||
}
|
||||
|
||||
@ -493,8 +493,8 @@ impl Backend {
|
||||
checker.restore_trash()
|
||||
}
|
||||
|
||||
pub fn db_query(&self, input: pb::DbQueryIn) -> Result<pb::DbQueryOut> {
|
||||
db_query_proto(&self.col, input)
|
||||
pub fn db_command(&self, input: &[u8]) -> Result<String> {
|
||||
db_command_bytes(&self.col, input)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,28 +4,12 @@
|
||||
use crate::err::Result;
|
||||
use crate::err::{AnkiError, DBErrorKind};
|
||||
use crate::time::i64_unix_timestamp;
|
||||
use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef};
|
||||
use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::borrow::Cow;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt;
|
||||
use rusqlite::{params, Connection, NO_PARAMS};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const SCHEMA_MIN_VERSION: u8 = 11;
|
||||
const SCHEMA_MAX_VERSION: u8 = 11;
|
||||
|
||||
macro_rules! cached_sql {
|
||||
( $label:expr, $db:expr, $sql:expr ) => {{
|
||||
if $label.is_none() {
|
||||
$label = Some($db.prepare_cached($sql)?);
|
||||
}
|
||||
$label.as_mut().unwrap()
|
||||
}};
|
||||
}
|
||||
|
||||
// currently public for dbproxy
|
||||
#[derive(Debug)]
|
||||
pub struct SqliteStorage {
|
||||
@ -42,6 +26,8 @@ fn open_or_create_collection_db(path: &Path) -> Result<Connection> {
|
||||
db.trace(Some(trace));
|
||||
}
|
||||
|
||||
db.busy_timeout(std::time::Duration::from_secs(0))?;
|
||||
|
||||
db.pragma_update(None, "locking_mode", &"exclusive")?;
|
||||
db.pragma_update(None, "page_size", &4096)?;
|
||||
db.pragma_update(None, "cache_size", &(-40 * 1024))?;
|
||||
@ -78,7 +64,6 @@ impl SqliteStorage {
|
||||
|
||||
let (create, ver) = schema_version(&db)?;
|
||||
if create {
|
||||
unimplemented!(); // todo
|
||||
db.prepare_cached("begin exclusive")?.execute(NO_PARAMS)?;
|
||||
db.execute_batch(include_str!("schema11.sql"))?;
|
||||
db.execute(
|
||||
@ -118,12 +103,16 @@ impl SqliteStorage {
|
||||
}
|
||||
|
||||
pub(crate) fn commit(&self) -> Result<()> {
|
||||
if !self.db.is_autocommit() {
|
||||
self.db.prepare_cached("commit")?.execute(NO_PARAMS)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn rollback(&self) -> Result<()> {
|
||||
if !self.db.is_autocommit() {
|
||||
self.db.execute("rollback", NO_PARAMS)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,10 @@
|
||||
use anki::backend::{
|
||||
init_backend, init_i18n_backend, Backend as RustBackend, I18nBackend as RustI18nBackend,
|
||||
};
|
||||
use pyo3::exceptions::Exception;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{exceptions, wrap_pyfunction};
|
||||
use pyo3::{create_exception, exceptions, wrap_pyfunction};
|
||||
|
||||
// Regular backend
|
||||
//////////////////////////////////
|
||||
@ -16,6 +17,8 @@ struct Backend {
|
||||
backend: RustBackend,
|
||||
}
|
||||
|
||||
create_exception!(ankirspy, DBError, Exception);
|
||||
|
||||
#[pyfunction]
|
||||
fn buildhash() -> &'static str {
|
||||
include_str!("../../meta/buildhash").trim()
|
||||
@ -71,11 +74,15 @@ impl Backend {
|
||||
}
|
||||
}
|
||||
|
||||
fn db_query(&mut self, py: Python, input: &PyBytes) -> PyObject {
|
||||
fn db_command(&mut self, py: Python, input: &PyBytes) -> PyResult<PyObject> {
|
||||
let in_bytes = input.as_bytes();
|
||||
let out_string = self.backend.db_query(in_bytes).unwrap();
|
||||
let out_string = self
|
||||
.backend
|
||||
.db_command(in_bytes)
|
||||
.map_err(|e| DBError::py_err(e.localized_description(&self.backend.i18n)))?;
|
||||
|
||||
let out_obj = PyBytes::new(py, out_string.as_bytes());
|
||||
out_obj.into()
|
||||
Ok(out_obj.into())
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user