wrap the collection in a mutex so DB access is thread safe

This commit is contained in:
Damien Elmes 2020-03-06 09:22:46 +10:00
parent 47c142a74c
commit 14546c8a8b

View File

@ -8,7 +8,7 @@ use crate::collection::{open_collection, Collection};
use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind}; use crate::err::{AnkiError, NetworkErrorKind, Result, SyncErrorKind};
use crate::i18n::{tr_args, FString, I18n}; use crate::i18n::{tr_args, FString, I18n};
use crate::latex::{extract_latex, extract_latex_expanding_clozes, ExtractedLatex}; use crate::latex::{extract_latex, extract_latex_expanding_clozes, ExtractedLatex};
use crate::log::default_logger; use crate::log::{default_logger, Logger};
use crate::media::check::MediaChecker; use crate::media::check::MediaChecker;
use crate::media::sync::MediaSyncProgress; use crate::media::sync::MediaSyncProgress;
use crate::media::MediaManager; use crate::media::MediaManager;
@ -24,6 +24,7 @@ use fluent::FluentValue;
use prost::Message; use prost::Message;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
mod dbproxy; mod dbproxy;
@ -31,12 +32,14 @@ mod dbproxy;
pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>; pub type ProtoProgressCallback = Box<dyn Fn(Vec<u8>) -> bool + Send>;
pub struct Backend { pub struct Backend {
col: Collection, col: Arc<Mutex<Collection>>,
#[allow(dead_code)] #[allow(dead_code)]
col_path: PathBuf, col_path: PathBuf,
media_folder: PathBuf, media_folder: PathBuf,
media_db: String, media_db: String,
progress_callback: Option<ProtoProgressCallback>, progress_callback: Option<ProtoProgressCallback>,
i18n: I18n,
log: Logger,
} }
enum Progress<'a> { enum Progress<'a> {
@ -122,7 +125,12 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
log::terminal(), log::terminal(),
); );
let col = open_collection(&input.collection_path, input.server, i18n, logger) let col = open_collection(
&input.collection_path,
input.server,
i18n.clone(),
logger.clone(),
)
.map_err(|e| format!("Unable to open collection: {:?}", e))?; .map_err(|e| format!("Unable to open collection: {:?}", e))?;
match Backend::new( match Backend::new(
@ -130,6 +138,8 @@ pub fn init_backend(init_msg: &[u8]) -> std::result::Result<Backend, String> {
&input.collection_path, &input.collection_path,
&input.media_folder_path, &input.media_folder_path,
&input.media_db_path, &input.media_db_path,
i18n,
logger,
) { ) {
Ok(backend) => Ok(backend), Ok(backend) => Ok(backend),
Err(e) => Err(format!("{:?}", e)), Err(e) => Err(format!("{:?}", e)),
@ -142,18 +152,22 @@ impl Backend {
col_path: &str, col_path: &str,
media_folder: &str, media_folder: &str,
media_db: &str, media_db: &str,
i18n: I18n,
log: Logger,
) -> Result<Backend> { ) -> Result<Backend> {
Ok(Backend { Ok(Backend {
col, col: Arc::new(Mutex::new(col)),
col_path: col_path.into(), col_path: col_path.into(),
media_folder: media_folder.into(), media_folder: media_folder.into(),
media_db: media_db.into(), media_db: media_db.into(),
progress_callback: None, progress_callback: None,
i18n,
log,
}) })
} }
pub fn i18n(&self) -> &I18n { pub fn i18n(&self) -> &I18n {
&self.col.i18n &self.i18n
} }
/// Decode a request, process it, and return the encoded result. /// Decode a request, process it, and return the encoded result.
@ -165,7 +179,7 @@ impl Backend {
Err(_e) => { Err(_e) => {
// unable to decode // unable to decode
let err = AnkiError::invalid_input("couldn't decode backend request"); let err = AnkiError::invalid_input("couldn't decode backend request");
let oerr = anki_error_to_proto_error(err, &self.col.i18n); let oerr = anki_error_to_proto_error(err, &self.i18n);
let output = pb::BackendOutput { let output = pb::BackendOutput {
value: Some(oerr.into()), value: Some(oerr.into()),
}; };
@ -183,12 +197,12 @@ impl Backend {
let oval = if let Some(ival) = input.value { let oval = if let Some(ival) = input.value {
match self.run_command_inner(ival) { match self.run_command_inner(ival) {
Ok(output) => output, Ok(output) => output,
Err(err) => anki_error_to_proto_error(err, &self.col.i18n).into(), Err(err) => anki_error_to_proto_error(err, &self.i18n).into(),
} }
} else { } else {
anki_error_to_proto_error( anki_error_to_proto_error(
AnkiError::invalid_input("unrecognized backend input value"), AnkiError::invalid_input("unrecognized backend input value"),
&self.col.i18n, &self.i18n,
) )
.into() .into()
}; };
@ -233,12 +247,12 @@ impl Backend {
Value::StudiedToday(input) => OValue::StudiedToday(studied_today( Value::StudiedToday(input) => OValue::StudiedToday(studied_today(
input.cards as usize, input.cards as usize,
input.seconds as f32, input.seconds as f32,
&self.col.i18n, &self.i18n,
)), )),
Value::CongratsLearnMsg(input) => OValue::CongratsLearnMsg(learning_congrats( Value::CongratsLearnMsg(input) => OValue::CongratsLearnMsg(learning_congrats(
input.remaining as usize, input.remaining as usize,
input.next_due, input.next_due,
&self.col.i18n, &self.i18n,
)), )),
Value::EmptyTrash(_) => { Value::EmptyTrash(_) => {
self.empty_trash()?; self.empty_trash()?;
@ -253,7 +267,7 @@ impl Backend {
fn fire_progress_callback(&self, progress: Progress) -> bool { fn fire_progress_callback(&self, progress: Progress) -> bool {
if let Some(cb) = &self.progress_callback { if let Some(cb) = &self.progress_callback {
let bytes = progress_to_proto_bytes(progress, &self.col.i18n); let bytes = progress_to_proto_bytes(progress, &self.i18n);
cb(bytes) cb(bytes)
} else { } else {
true true
@ -333,7 +347,7 @@ impl Backend {
&input.answer_template, &input.answer_template,
&fields, &fields,
input.card_ordinal as u16, input.card_ordinal as u16,
&self.col.i18n, &self.i18n,
)?; )?;
// return // return
@ -411,7 +425,7 @@ impl Backend {
}; };
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, self.col.log.clone())) rt.block_on(mgr.sync_media(callback, &input.endpoint, &input.hkey, self.log.clone()))
} }
fn check_media(&self) -> Result<pb::MediaCheckOut> { fn check_media(&self) -> Result<pb::MediaCheckOut> {
@ -419,7 +433,7 @@ impl Backend {
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?; let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
self.col.transact(None, |ctx| { self.col.lock().unwrap().transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback); let mut checker = MediaChecker::new(ctx, &mgr, callback);
let mut output = checker.check()?; let mut output = checker.check()?;
@ -452,7 +466,7 @@ impl Backend {
.map(|(k, v)| (k.as_str(), translate_arg_to_fluent_val(&v))) .map(|(k, v)| (k.as_str(), translate_arg_to_fluent_val(&v)))
.collect(); .collect();
self.col.i18n.trn(key, map) self.i18n.trn(key, map)
} }
fn format_time_span(&self, input: pb::FormatTimeSpanIn) -> String { fn format_time_span(&self, input: pb::FormatTimeSpanIn) -> String {
@ -461,14 +475,12 @@ impl Backend {
None => return "".to_string(), None => return "".to_string(),
}; };
match context { match context {
pb::format_time_span_in::Context::Precise => { pb::format_time_span_in::Context::Precise => time_span(input.seconds, &self.i18n, true),
time_span(input.seconds, &self.col.i18n, true)
}
pb::format_time_span_in::Context::Intervals => { pb::format_time_span_in::Context::Intervals => {
time_span(input.seconds, &self.col.i18n, false) time_span(input.seconds, &self.i18n, false)
} }
pb::format_time_span_in::Context::AnswerButtons => { pb::format_time_span_in::Context::AnswerButtons => {
answer_button_time(input.seconds, &self.col.i18n) answer_button_time(input.seconds, &self.i18n)
} }
} }
} }
@ -478,7 +490,7 @@ impl Backend {
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?; let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
self.col.transact(None, |ctx| { self.col.lock().unwrap().transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback); let mut checker = MediaChecker::new(ctx, &mgr, callback);
checker.empty_trash() checker.empty_trash()
@ -490,7 +502,7 @@ impl Backend {
|progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32)); |progress: usize| self.fire_progress_callback(Progress::MediaCheck(progress as u32));
let mgr = MediaManager::new(&self.media_folder, &self.media_db)?; let mgr = MediaManager::new(&self.media_folder, &self.media_db)?;
self.col.transact(None, |ctx| { self.col.lock().unwrap().transact(None, |ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, callback); let mut checker = MediaChecker::new(ctx, &mgr, callback);
checker.restore_trash() checker.restore_trash()
@ -498,7 +510,10 @@ impl Backend {
} }
pub fn db_command(&self, input: &[u8]) -> Result<String> { pub fn db_command(&self, input: &[u8]) -> Result<String> {
db_command_bytes(&self.col.storage.context(self.col.server), input) self.col
.lock()
.unwrap()
.with_ctx(|ctx| db_command_bytes(&ctx.storage, input))
} }
} }