diff --git a/qt/aqt/mediacheck.py b/qt/aqt/mediacheck.py index 44700b18f..3682af587 100644 --- a/qt/aqt/mediacheck.py +++ b/qt/aqt/mediacheck.py @@ -59,6 +59,7 @@ class MediaChecker: def _set_progress_enabled(self, enabled: bool) -> None: if self._progress_timer: + self._progress_timer.stop() self._progress_timer.deleteLater() self._progress_timer = None if enabled: diff --git a/qt/aqt/mediasync.py b/qt/aqt/mediasync.py index d0796aae7..a5617c0b6 100644 --- a/qt/aqt/mediasync.py +++ b/qt/aqt/mediasync.py @@ -78,6 +78,7 @@ class MediaSyncer: def _on_finished(self, future: Future) -> None: self._syncing = False if self._progress_timer: + self._progress_timer.stop() self._progress_timer.deleteLater() self._progress_timer = None gui_hooks.media_sync_did_start_or_stop(False) diff --git a/qt/aqt/progress.py b/qt/aqt/progress.py index fd2561ef6..bd85834fb 100644 --- a/qt/aqt/progress.py +++ b/qt/aqt/progress.py @@ -236,6 +236,7 @@ class ProgressManager: self._show_timer.stop() self._show_timer = None if self._backend_timer: + self._backend_timer.stop() self._backend_timer.deleteLater() self._backend_timer = None diff --git a/rslib/src/backend/collection.rs b/rslib/src/backend/collection.rs index 9092f10b5..4cdc2a242 100644 --- a/rslib/src/backend/collection.rs +++ b/rslib/src/backend/collection.rs @@ -7,11 +7,10 @@ pub(super) use anki_proto::collection::collection_service::Service as Collection use anki_proto::generic; use tracing::error; -use super::progress::Progress; use super::Backend; -use crate::backend::progress::progress_to_proto; use crate::collection::CollectionBuilder; use crate::prelude::*; +use crate::progress::progress_to_proto; use crate::storage::SchemaVersion; impl CollectionService for Backend { @@ -38,7 +37,8 @@ impl CollectionService for Backend { .set_force_schema11(input.force_schema11) .set_media_paths(input.media_folder_path, input.media_db_path) .set_server(self.server) - .set_tr(self.tr.clone()); + .set_tr(self.tr.clone()) + .set_shared_progress_state(self.progress_state.clone()); *guard = Some(builder.build()?); @@ -70,16 +70,11 @@ impl CollectionService for Backend { &self, _input: generic::Empty, ) -> Result { - let mut handler = self.new_progress_handler(); - let progress_fn = move |progress, throttle| { - handler.update(Progress::DatabaseCheck(progress), throttle); - }; self.with_col(|col| { - col.check_database(progress_fn).map(|problems| { - anki_proto::collection::CheckDatabaseResponse { + col.check_database() + .map(|problems| anki_proto::collection::CheckDatabaseResponse { problems: problems.to_i18n_strings(&col.tr), - } - }) + }) }) } diff --git a/rslib/src/backend/import_export.rs b/rslib/src/backend/import_export.rs index 6b4b0dd47..6e1ccb519 100644 --- a/rslib/src/backend/import_export.rs +++ b/rslib/src/backend/import_export.rs @@ -8,11 +8,8 @@ use anki_proto::import_export::export_limit; pub(super) use anki_proto::import_export::importexport_service::Service as ImportExportService; use anki_proto::import_export::ExportLimit; -use super::progress::Progress; use super::Backend; use crate::import_export::package::import_colpkg; -use crate::import_export::ExportProgress; -use crate::import_export::ImportProgress; use crate::import_export::NoteLog; use crate::prelude::*; use crate::search::SearchNode; @@ -30,12 +27,7 @@ impl ImportExportService for Backend { let col_inner = guard.take().unwrap(); col_inner - .export_colpkg( - input.out_path, - input.include_media, - input.legacy, - self.export_progress_fn(), - ) + .export_colpkg(input.out_path, input.include_media, input.legacy) .map(Into::into) } @@ -50,7 +42,7 @@ impl ImportExportService for Backend { &input.col_path, Path::new(&input.media_folder), Path::new(&input.media_db), - self.import_progress_fn(), + self.new_progress_handler(), ) .map(Into::into) } @@ -59,7 +51,7 @@ impl ImportExportService for Backend { &self, input: anki_proto::import_export::ImportAnkiPackageRequest, ) -> Result { - self.with_col(|col| col.import_apkg(&input.package_path, self.import_progress_fn())) + self.with_col(|col| col.import_apkg(&input.package_path)) .map(Into::into) } @@ -75,7 +67,6 @@ impl ImportExportService for Backend { input.with_media, input.legacy, None, - self.export_progress_fn(), ) }) .map(Into::into) @@ -101,21 +92,15 @@ impl ImportExportService for Backend { &self, input: anki_proto::import_export::ImportCsvRequest, ) -> Result { - self.with_col(|col| { - col.import_csv( - &input.path, - input.metadata.unwrap_or_default(), - self.import_progress_fn(), - ) - }) - .map(Into::into) + self.with_col(|col| col.import_csv(&input.path, input.metadata.unwrap_or_default())) + .map(Into::into) } fn export_note_csv( &self, input: anki_proto::import_export::ExportNoteCsvRequest, ) -> Result { - self.with_col(|col| col.export_note_csv(input, self.export_progress_fn())) + self.with_col(|col| col.export_note_csv(input)) .map(Into::into) } @@ -128,7 +113,6 @@ impl ImportExportService for Backend { &input.out_path, SearchNode::from(input.limit.unwrap_or_default()), input.with_html, - self.export_progress_fn(), ) }) .map(Into::into) @@ -138,7 +122,7 @@ impl ImportExportService for Backend { &self, input: generic::String, ) -> Result { - self.with_col(|col| col.import_json_file(&input.val, self.import_progress_fn())) + self.with_col(|col| col.import_json_file(&input.val)) .map(Into::into) } @@ -146,23 +130,11 @@ impl ImportExportService for Backend { &self, input: generic::String, ) -> Result { - self.with_col(|col| col.import_json_string(&input.val, self.import_progress_fn())) + self.with_col(|col| col.import_json_string(&input.val)) .map(Into::into) } } -impl Backend { - fn import_progress_fn(&self) -> impl FnMut(ImportProgress, bool) -> bool { - let mut handler = self.new_progress_handler(); - move |progress, throttle| handler.update(Progress::Import(progress), throttle) - } - - fn export_progress_fn(&self) -> impl FnMut(ExportProgress, bool) -> bool { - let mut handler = self.new_progress_handler(); - move |progress, throttle| handler.update(Progress::Export(progress), throttle) - } -} - impl From> for anki_proto::import_export::ImportResponse { fn from(output: OpOutput) -> Self { Self { diff --git a/rslib/src/backend/media.rs b/rslib/src/backend/media.rs index 9a9ab2d8e..49362aff8 100644 --- a/rslib/src/backend/media.rs +++ b/rslib/src/backend/media.rs @@ -5,29 +5,20 @@ use anki_proto::generic; pub(super) use anki_proto::media::media_service::Service as MediaService; use super::notes::to_i64s; -use super::progress::Progress; use super::Backend; -use crate::media::check::MediaChecker; use crate::prelude::*; impl MediaService for Backend { type Error = AnkiError; - // media - //----------------------------------------------- - fn check_media(&self, _input: generic::Empty) -> Result { - let mut handler = self.new_progress_handler(); - let progress_fn = - move |progress| handler.update(Progress::MediaCheck(progress as u32), true); self.with_col(|col| { - let mgr = col.media()?; - col.transact_no_undo(|ctx| { - let mut checker = MediaChecker::new(ctx, &mgr, progress_fn); + col.transact_no_undo(|col| { + let mut checker = col.media_checker()?; let mut output = checker.check()?; let mut report = checker.summarize_output(&mut output); - ctx.report_media_field_referencing_templates(&mut report)?; + col.report_media_field_referencing_templates(&mut report)?; Ok(anki_proto::media::CheckMediaResponse { unused: output.unused, @@ -44,11 +35,8 @@ impl MediaService for Backend { &self, input: anki_proto::media::TrashMediaFilesRequest, ) -> Result { - self.with_col(|col| { - let mgr = col.media()?; - mgr.remove_files(&input.fnames) - }) - .map(Into::into) + self.with_col(|col| col.media()?.remove_files(&input.fnames)) + .map(Into::into) } fn add_media_file( @@ -56,8 +44,8 @@ impl MediaService for Backend { input: anki_proto::media::AddMediaFileRequest, ) -> Result { self.with_col(|col| { - let mgr = col.media()?; - Ok(mgr + Ok(col + .media()? .add_file(&input.desired_name, &input.data)? .to_string() .into()) @@ -65,27 +53,12 @@ impl MediaService for Backend { } fn empty_trash(&self, _input: generic::Empty) -> Result { - let mut handler = self.new_progress_handler(); - let progress_fn = - move |progress| handler.update(Progress::MediaCheck(progress as u32), true); - - self.with_col(|col| { - let mgr = col.media()?; - let mut checker = MediaChecker::new(col, &mgr, progress_fn); - checker.empty_trash() - }) - .map(Into::into) + self.with_col(|col| col.media_checker()?.empty_trash()) + .map(Into::into) } fn restore_trash(&self, _input: generic::Empty) -> Result { - let mut handler = self.new_progress_handler(); - let progress_fn = - move |progress| handler.update(Progress::MediaCheck(progress as u32), true); - self.with_col(|col| { - let mgr = col.media()?; - let mut checker = MediaChecker::new(col, &mgr, progress_fn); - checker.restore_trash() - }) - .map(Into::into) + self.with_col(|col| col.media_checker()?.restore_trash()) + .map(Into::into) } } diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs index 25765c68c..c7b790f9d 100644 --- a/rslib/src/backend/mod.rs +++ b/rslib/src/backend/mod.rs @@ -22,7 +22,6 @@ mod media; mod notes; mod notetypes; mod ops; -mod progress; mod scheduler; mod search; mod stats; @@ -36,7 +35,6 @@ use std::thread::JoinHandle; use anki_proto::ServiceIndex; use once_cell::sync::OnceCell; -use progress::AbortHandleSlot; use prost::Message; use tokio::runtime; use tokio::runtime::Runtime; @@ -55,7 +53,6 @@ use self::links::LinksService; use self::media::MediaService; use self::notes::NotesService; use self::notetypes::NotetypesService; -use self::progress::ProgressState; use self::scheduler::SchedulerService; use self::search::SearchService; use self::stats::StatsService; @@ -64,6 +61,10 @@ use self::sync::SyncState; use self::tags::TagsService; use crate::backend::dbproxy::db_command_bytes; use crate::prelude::*; +use crate::progress::AbortHandleSlot; +use crate::progress::Progress; +use crate::progress::ProgressState; +use crate::progress::ThrottlingProgressHandler; pub struct Backend { col: Arc>>, @@ -196,4 +197,13 @@ impl Backend { fn db_command(&self, input: &[u8]) -> Result> { self.with_col(|col| db_command_bytes(col, input)) } + + /// Useful for operations that function with a closed collection, such as + /// a colpkg import. For collection operations, you can use + /// [Collection::new_progress_handler] instead. + pub(crate) fn new_progress_handler + Default + Clone>( + &self, + ) -> ThrottlingProgressHandler

{ + ThrottlingProgressHandler::new(self.progress_state.clone()) + } } diff --git a/rslib/src/backend/sync/mod.rs b/rslib/src/backend/sync/mod.rs index 93ee7f000..6ffd7b81a 100644 --- a/rslib/src/backend/sync/mod.rs +++ b/rslib/src/backend/sync/mod.rs @@ -13,15 +13,13 @@ use futures::future::Abortable; use reqwest::Url; use tracing::warn; -use super::progress::AbortHandleSlot; use super::Backend; use crate::prelude::*; +use crate::progress::AbortHandleSlot; use crate::sync::collection::normal::ClientSyncState; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::SyncActionRequired; use crate::sync::collection::normal::SyncOutput; use crate::sync::collection::progress::sync_abort; -use crate::sync::collection::progress::FullSyncProgress; use crate::sync::collection::status::online_sync_status_check; use crate::sync::http_client::HttpSyncClient; use crate::sync::login::sync_login; @@ -198,12 +196,13 @@ impl Backend { } // start the sync - let mgr = self.col.lock().unwrap().as_mut().unwrap().media()?; - let mut handler = self.new_progress_handler(); - let progress_fn = move |progress| handler.update(progress, true); - + let (mgr, progress) = { + let mut col = self.col.lock().unwrap(); + let col = col.as_mut().unwrap(); + (col.media()?, col.new_progress_handler()) + }; let rt = self.runtime_handle(); - let sync_fut = mgr.sync_media(progress_fn, auth); + let sync_fut = mgr.sync_media(progress, auth); let abortable_sync = Abortable::new(sync_fut, abort_reg); let result = rt.block_on(abortable_sync); @@ -308,12 +307,7 @@ impl Backend { let rt = self.runtime_handle(); let ret = self.with_col(|col| { - let mut handler = self.new_progress_handler(); - let progress_fn = move |progress: NormalSyncProgress, throttle: bool| { - handler.update(progress, throttle); - }; - - let sync_fut = col.normal_sync(auth.clone(), progress_fn); + let sync_fut = col.normal_sync(auth.clone()); let abortable_sync = Abortable::new(sync_fut, abort_reg); match rt.block_on(abortable_sync) { @@ -360,19 +354,14 @@ impl Backend { let (_guard, abort_reg) = self.sync_abort_handle()?; - let builder = col_inner.as_builder(); - - let mut handler = self.new_progress_handler(); - let progress_fn = Box::new(move |progress: FullSyncProgress, throttle: bool| { - handler.update(progress, throttle); - }); + let mut builder = col_inner.as_builder(); let result = if upload { - let sync_fut = col_inner.full_upload(auth, progress_fn); + let sync_fut = col_inner.full_upload(auth); let abortable_sync = Abortable::new(sync_fut, abort_reg); rt.block_on(abortable_sync) } else { - let sync_fut = col_inner.full_download(auth, progress_fn); + let sync_fut = col_inner.full_download(auth); let abortable_sync = Abortable::new(sync_fut, abort_reg); rt.block_on(abortable_sync) }; diff --git a/rslib/src/collection/mod.rs b/rslib/src/collection/mod.rs index 0da766549..d80b0798a 100644 --- a/rslib/src/collection/mod.rs +++ b/rslib/src/collection/mod.rs @@ -11,6 +11,7 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::path::PathBuf; use std::sync::Arc; +use std::sync::Mutex; use anki_i18n::I18n; use anki_io::create_dir_all; @@ -21,6 +22,7 @@ use crate::decks::DeckId; use crate::error::Result; use crate::notetype::Notetype; use crate::notetype::NotetypeId; +use crate::progress::ProgressState; use crate::scheduler::queue::CardQueues; use crate::scheduler::SchedulerInfo; use crate::storage::SchemaVersion; @@ -39,6 +41,7 @@ pub struct CollectionBuilder { check_integrity: bool, // temporary option for AnkiDroid force_schema11: Option, + progress_handler: Option>>, } impl CollectionBuilder { @@ -50,7 +53,7 @@ impl CollectionBuilder { builder } - pub fn build(&self) -> Result { + pub fn build(&mut self) -> Result { let col_path = self .collection_path .clone() @@ -74,7 +77,10 @@ impl CollectionBuilder { media_db, tr, server, - state: CollectionState::default(), + state: CollectionState { + progress: self.progress_handler.clone().unwrap_or_default(), + ..Default::default() + }, }; Ok(col) @@ -121,6 +127,13 @@ impl CollectionBuilder { self.check_integrity = check_integrity; self } + + /// If provided, progress info will be written to the provided mutex, and + /// can be tracked on a separate thread. + pub fn set_shared_progress_state(&mut self, state: Arc>) -> &mut Self { + self.progress_handler = Some(state); + self + } } #[derive(Debug, Default)] @@ -137,11 +150,11 @@ pub struct CollectionState { /// The modification time at the last backup, so we don't create multiple /// identical backups. pub(crate) last_backup_modified: Option, + pub(crate) progress: Arc>, } pub struct Collection { pub storage: SqliteStorage, - #[allow(dead_code)] pub(crate) col_path: PathBuf, pub(crate) media_folder: PathBuf, pub(crate) media_db: PathBuf, diff --git a/rslib/src/dbcheck.rs b/rslib/src/dbcheck.rs index 550d3cde1..310d8727e 100644 --- a/rslib/src/dbcheck.rs +++ b/rslib/src/dbcheck.rs @@ -21,6 +21,7 @@ use crate::notetype::Notetype; use crate::notetype::NotetypeId; use crate::notetype::NotetypeKind; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::timestamp::TimestampMillis; use crate::timestamp::TimestampSecs; @@ -39,12 +40,16 @@ pub struct CheckDatabaseOutput { invalid_ids: usize, } -#[derive(Debug, Clone, Copy)] -pub(crate) enum DatabaseCheckProgress { +#[derive(Debug, Clone, Copy, Default)] +pub enum DatabaseCheckProgress { + #[default] Integrity, Optimize, Cards, - Notes { current: u32, total: u32 }, + Notes { + current: usize, + total: usize, + }, History, } @@ -93,11 +98,9 @@ impl CheckDatabaseOutput { impl Collection { /// Check the database, returning a list of problems that were fixed. - pub(crate) fn check_database(&mut self, mut progress_fn: F) -> Result - where - F: FnMut(DatabaseCheckProgress, bool), - { - progress_fn(DatabaseCheckProgress::Integrity, false); + pub(crate) fn check_database(&mut self) -> Result { + let mut progress = self.new_progress_handler(); + progress.set(DatabaseCheckProgress::Integrity)?; debug!("quick check"); if self.storage.quick_check_corrupt() { debug!("quick check failed"); @@ -107,21 +110,21 @@ impl Collection { )); } - progress_fn(DatabaseCheckProgress::Optimize, false); + progress.set(DatabaseCheckProgress::Optimize)?; debug!("optimize"); self.storage.optimize()?; - self.transact_no_undo(|col| col.check_database_inner(progress_fn)) + self.transact_no_undo(|col| col.check_database_inner(progress)) } - fn check_database_inner(&mut self, mut progress_fn: F) -> Result - where - F: FnMut(DatabaseCheckProgress, bool), - { + fn check_database_inner( + &mut self, + mut progress: ThrottlingProgressHandler, + ) -> Result { let mut out = CheckDatabaseOutput::default(); // cards first, as we need to be able to read them to process notes - progress_fn(DatabaseCheckProgress::Cards, false); + progress.set(DatabaseCheckProgress::Cards)?; debug!("check cards"); self.check_card_properties(&mut out)?; self.check_orphaned_cards(&mut out)?; @@ -131,9 +134,9 @@ impl Collection { self.check_filtered_cards(&mut out)?; debug!("check notetypes"); - self.check_notetypes(&mut out, &mut progress_fn)?; + self.check_notetypes(&mut out, &mut progress)?; - progress_fn(DatabaseCheckProgress::History, false); + progress.set(DatabaseCheckProgress::History)?; debug!("check review log"); self.check_revlog(&mut out)?; @@ -207,14 +210,11 @@ impl Collection { Ok(()) } - fn check_notetypes( + fn check_notetypes( &mut self, out: &mut CheckDatabaseOutput, - mut progress_fn: F, - ) -> Result<()> - where - F: FnMut(DatabaseCheckProgress, bool), - { + progress: &mut ThrottlingProgressHandler, + ) -> Result<()> { let nids_by_notetype = self.storage.all_note_ids_by_notetype()?; let norm = self.get_config_bool(BoolKey::NormalizeNoteText); let usn = self.usn()?; @@ -225,8 +225,11 @@ impl Collection { self.storage.clear_all_tags()?; let total_notes = self.storage.total_notes()?; - let mut checked_notes = 0; + progress.set(DatabaseCheckProgress::Notes { + current: 0, + total: total_notes as usize, + })?; for (ntid, group) in &nids_by_notetype.into_iter().group_by(|tup| tup.0) { debug!("check notetype: {}", ntid); let mut group = group.peekable(); @@ -241,14 +244,10 @@ impl Collection { let mut genctx = None; for (_, nid) in group { - progress_fn( - DatabaseCheckProgress::Notes { - current: checked_notes, - total: total_notes, - }, - true, - ); - checked_notes += 1; + progress.increment(|p| { + let DatabaseCheckProgress::Notes { current, .. } = p else { unreachable!() }; + current + })?; let mut note = self.get_note_fixing_invalid_utf8(nid, out)?; let original = note.clone(); @@ -434,8 +433,6 @@ mod test { use crate::decks::DeckId; use crate::search::SortMode; - fn progress_fn(_progress: DatabaseCheckProgress, _throttle: bool) {} - #[test] fn cards() -> Result<()> { let mut col = Collection::new(); @@ -448,7 +445,7 @@ mod test { .db .execute_batch("update cards set ivl=1.5,due=2000000,odue=1.5")?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -458,12 +455,12 @@ mod test { } ); // should be idempotent - assert_eq!(col.check_database(progress_fn)?, Default::default()); + assert_eq!(col.check_database()?, Default::default()); // missing deck col.storage.db.execute_batch("update cards set did=123")?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -482,7 +479,7 @@ mod test { // missing note col.storage.remove_note(note.id)?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -508,7 +505,7 @@ mod test { values (0,0,0,0,1.5,1.5,0,0,0)", )?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -536,7 +533,7 @@ mod test { card.id.0 += 1; col.storage.add_card(&mut card)?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -556,7 +553,7 @@ mod test { card.template_idx = 10; col.storage.add_card(&mut card)?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -583,7 +580,7 @@ mod test { col.storage .db .execute_batch("update notes set flds = 'a\x1fb\x1fc\x1fd'")?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -598,7 +595,7 @@ mod test { col.storage .db .execute_batch("update notes set flds = 'a'")?; - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -626,7 +623,7 @@ mod test { .execute([deck.id])?; assert_eq!(col.storage.get_all_deck_names()?.len(), 2); - let out = col.check_database(progress_fn)?; + let out = col.check_database()?; assert_eq!( out, CheckDatabaseOutput { @@ -657,7 +654,7 @@ mod test { col.set_tag_collapsed("one", false)?; - col.check_database(progress_fn)?; + col.check_database()?; assert!(col.storage.get_tag("one")?.unwrap().expanded); assert!(!col.storage.get_tag("two")?.unwrap().expanded); diff --git a/rslib/src/import_export/gather.rs b/rslib/src/import_export/gather.rs index 3182defe5..ea6afd974 100644 --- a/rslib/src/import_export/gather.rs +++ b/rslib/src/import_export/gather.rs @@ -8,10 +8,10 @@ use anki_io::filename_is_safe; use itertools::Itertools; use super::ExportProgress; -use super::IncrementableProgress; use crate::decks::immediate_parent_name; use crate::latex::extract_latex; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::revlog::RevlogEntry; use crate::search::CardTableGuard; use crate::search::NoteTableGuard; @@ -61,7 +61,7 @@ impl ExchangeData { pub(super) fn gather_media_names( &mut self, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { let mut inserter = |name: String| { if filename_is_safe(&name) { diff --git a/rslib/src/import_export/mod.rs b/rslib/src/import_export/mod.rs index 3b070def3..c3ac029ac 100644 --- a/rslib/src/import_export/mod.rs +++ b/rslib/src/import_export/mod.rs @@ -6,8 +6,6 @@ mod insert; pub mod package; pub mod text; -use std::marker::PhantomData; - pub use anki_proto::import_export::import_response::Log as NoteLog; pub use anki_proto::import_export::import_response::Note as LogNote; use snafu::Snafu; @@ -18,18 +16,20 @@ use crate::text::strip_html_preserving_media_filenames; use crate::text::truncate_to_char_boundary; use crate::text::CowMapping; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ImportProgress { - File, + #[default] Extracting, + File, Gathering, Media(usize), MediaCheck(usize), Notes(usize), } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ExportProgress { + #[default] File, Gathering, Notes(usize), @@ -37,80 +37,6 @@ pub enum ExportProgress { Media(usize), } -/// Wrapper around a progress function, usually passed by the -/// [crate::backend::Backend], to make repeated calls more ergonomic. -pub(crate) struct IncrementableProgress

(Box bool>); - -impl

IncrementableProgress

{ - /// `progress_fn: (progress, throttle) -> should_continue` - pub(crate) fn new(progress_fn: impl 'static + FnMut(P, bool) -> bool) -> Self { - Self(Box::new(progress_fn)) - } - - /// Returns an [Incrementor] with an `increment()` function for use in - /// loops. - pub(crate) fn incrementor<'inc, 'progress: 'inc, 'map: 'inc>( - &'progress mut self, - mut count_map: impl 'map + FnMut(usize) -> P, - ) -> Incrementor<'inc, impl FnMut(usize) -> Result<()> + 'inc> { - Incrementor::new(move |u| self.update(count_map(u), true)) - } - - /// Manually triggers an update. - /// Returns [AnkiError::Interrupted] if the operation should be cancelled. - pub(crate) fn call(&mut self, progress: P) -> Result<()> { - self.update(progress, false) - } - - fn update(&mut self, progress: P, throttle: bool) -> Result<()> { - if (self.0)(progress, throttle) { - Ok(()) - } else { - Err(AnkiError::Interrupted) - } - } - - /// Stopgap for returning a progress fn compliant with the media code. - pub(crate) fn media_db_fn( - &mut self, - count_map: impl 'static + Fn(usize) -> P, - ) -> Result bool + '_> { - Ok(move |count| (self.0)(count_map(count), true)) - } -} - -pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> { - update_fn: F, - count: usize, - update_interval: usize, - _phantom: PhantomData<&'f ()>, -} - -impl<'f, F: 'f + FnMut(usize) -> Result<()>> Incrementor<'f, F> { - fn new(update_fn: F) -> Self { - Self { - update_fn, - count: 0, - update_interval: 17, - _phantom: PhantomData, - } - } - - /// Increments the progress counter, periodically triggering an update. - /// Returns [AnkiError::Interrupted] if the operation should be cancelled. - pub(crate) fn increment(&mut self) -> Result<()> { - self.count += 1; - if self.count % self.update_interval != 0 { - return Ok(()); - } - (self.update_fn)(self.count) - } - - pub(crate) fn count(&self) -> usize { - self.count - } -} - impl Note { pub(crate) fn into_log_note(self) -> LogNote { LogNote { diff --git a/rslib/src/import_export/package/apkg/export.rs b/rslib/src/import_export/package/apkg/export.rs index 9fd5c2d56..6bf16f0a8 100644 --- a/rslib/src/import_export/package/apkg/export.rs +++ b/rslib/src/import_export/package/apkg/export.rs @@ -16,8 +16,8 @@ use crate::import_export::package::colpkg::export::export_collection; use crate::import_export::package::media::MediaIter; use crate::import_export::package::Meta; use crate::import_export::ExportProgress; -use crate::import_export::IncrementableProgress; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; impl Collection { /// Returns number of exported notes. @@ -30,10 +30,8 @@ impl Collection { with_media: bool, legacy: bool, media_fn: Option) -> MediaIter>>, - progress_fn: impl 'static + FnMut(ExportProgress, bool) -> bool, ) -> Result { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ExportProgress::File)?; + let mut progress = self.new_progress_handler(); let temp_apkg = new_tempfile_in_parent_of(out_path.as_ref())?; let mut temp_col = new_tempfile()?; let temp_col_path = temp_col @@ -54,7 +52,7 @@ impl Collection { with_media, )?; - progress.call(ExportProgress::File)?; + progress.set(ExportProgress::File)?; let media = if let Some(media_fn) = media_fn { media_fn(data.media_filenames) } else { @@ -80,19 +78,19 @@ impl Collection { meta: &Meta, path: &str, search: impl TryIntoSearch, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, with_scheduling: bool, with_media: bool, ) -> Result { let mut data = ExchangeData::default(); - progress.call(ExportProgress::Gathering)?; + progress.set(ExportProgress::Gathering)?; data.gather_data(self, search, with_scheduling)?; if with_media { data.gather_media_names(progress)?; } let mut temp_col = Collection::new_minimal(path)?; - progress.call(ExportProgress::File)?; + progress.set(ExportProgress::File)?; temp_col.insert_data(&data)?; temp_col.set_creation_stamp(self.storage.creation_stamp()?)?; temp_col.set_creation_utc_offset(data.creation_utc_offset)?; diff --git a/rslib/src/import_export/package/apkg/import/media.rs b/rslib/src/import_export/package/apkg/import/media.rs index f9e7ca391..d467a5333 100644 --- a/rslib/src/import_export/package/apkg/import/media.rs +++ b/rslib/src/import_export/package/apkg/import/media.rs @@ -15,10 +15,10 @@ use crate::import_export::package::media::extract_media_entries; use crate::import_export::package::media::MediaCopier; use crate::import_export::package::media::SafeMediaEntry; use crate::import_export::ImportProgress; -use crate::import_export::IncrementableProgress; use crate::media::files::add_hash_suffix_to_file_stem; use crate::media::files::sha1_of_reader; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; /// Map of source media files, that do not already exist in the target. #[derive(Default)] @@ -76,7 +76,7 @@ fn prepare_media( media_entries: Vec, archive: &mut ZipArchive, existing_sha1s: &HashMap, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result { let mut media_map = MediaUseMap::default(); let mut incrementor = progress.incrementor(ImportProgress::MediaCheck); diff --git a/rslib/src/import_export/package/apkg/import/mod.rs b/rslib/src/import_export/package/apkg/import/mod.rs index ddc32384a..fcce79db3 100644 --- a/rslib/src/import_export/package/apkg/import/mod.rs +++ b/rslib/src/import_export/package/apkg/import/mod.rs @@ -24,10 +24,10 @@ use crate::collection::CollectionBuilder; use crate::import_export::gather::ExchangeData; use crate::import_export::package::Meta; use crate::import_export::ImportProgress; -use crate::import_export::IncrementableProgress; use crate::import_export::NoteLog; use crate::media::MediaManager; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::search::SearchNode; struct Context<'a> { @@ -37,20 +37,17 @@ struct Context<'a> { meta: Meta, data: ExchangeData, usn: Usn, - progress: IncrementableProgress, + progress: ThrottlingProgressHandler, } impl Collection { - pub fn import_apkg( - &mut self, - path: impl AsRef, - progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, - ) -> Result> { + pub fn import_apkg(&mut self, path: impl AsRef) -> Result> { let file = open_file(path)?; let archive = ZipArchive::new(file)?; + let progress = self.new_progress_handler(); self.transact(Op::Import, |col| { - let mut ctx = Context::new(archive, col, progress_fn)?; + let mut ctx = Context::new(archive, col, progress)?; ctx.import() }) } @@ -60,10 +57,8 @@ impl<'a> Context<'a> { fn new( mut archive: ZipArchive, target_col: &'a mut Collection, - progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + mut progress: ThrottlingProgressHandler, ) -> Result { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ImportProgress::Extracting)?; let media_manager = target_col.media()?; let meta = Meta::from_archive(&mut archive)?; let data = ExchangeData::gather_from_archive( @@ -102,7 +97,7 @@ impl ExchangeData { archive: &mut ZipArchive, meta: &Meta, search: impl TryIntoSearch, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, with_scheduling: bool, ) -> Result { let tempfile = collection_to_tempfile(meta, archive)?; @@ -110,7 +105,7 @@ impl ExchangeData { col.maybe_fix_invalid_ids()?; col.maybe_upgrade_scheduler()?; - progress.call(ImportProgress::Gathering)?; + progress.set(ImportProgress::Gathering)?; let mut data = ExchangeData::default(); data.gather_data(&mut col, search, with_scheduling)?; diff --git a/rslib/src/import_export/package/apkg/import/notes.rs b/rslib/src/import_export/package/apkg/import/notes.rs index 4782e1280..68035bfb0 100644 --- a/rslib/src/import_export/package/apkg/import/notes.rs +++ b/rslib/src/import_export/package/apkg/import/notes.rs @@ -14,9 +14,9 @@ use super::media::MediaUseMap; use super::Context; use crate::import_export::package::media::safe_normalized_file_name; use crate::import_export::ImportProgress; -use crate::import_export::IncrementableProgress; use crate::import_export::NoteLog; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::text::replace_media_refs; struct NoteContext<'a> { @@ -164,7 +164,7 @@ impl<'n> NoteContext<'n> { fn import_notes( &mut self, notes: Vec, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { let mut incrementor = progress.incrementor(ImportProgress::Notes); @@ -297,15 +297,15 @@ mod test { macro_rules! import_note { ($col:expr, $note:expr, $old_notetype:expr => $new_notetype:expr) => {{ let mut media_map = MediaUseMap::default(); + let mut progress = $col.new_progress_handler(); let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut media_map).unwrap(); ctx.remapped_notetypes.insert($old_notetype, $new_notetype); - let mut progress = IncrementableProgress::new(|_, _| true); ctx.import_notes(vec![$note], &mut progress).unwrap(); ctx.imports.log }}; ($col:expr, $note:expr, $media_map:expr) => {{ + let mut progress = $col.new_progress_handler(); let mut ctx = NoteContext::new(Usn(1), &mut $col, &mut $media_map).unwrap(); - let mut progress = IncrementableProgress::new(|_, _| true); ctx.import_notes(vec![$note], &mut progress).unwrap(); ctx.imports.log }}; diff --git a/rslib/src/import_export/package/apkg/tests.rs b/rslib/src/import_export/package/apkg/tests.rs index 303812ab3..2553334a1 100644 --- a/rslib/src/import_export/package/apkg/tests.rs +++ b/rslib/src/import_export/package/apkg/tests.rs @@ -48,10 +48,9 @@ fn roundtrip_inner(legacy: bool) { true, legacy, None, - |_, _| true, ) .unwrap(); - target_col.import_apkg(&apkg_path, |_, _| true).unwrap(); + target_col.import_apkg(&apkg_path).unwrap(); target_col.assert_decks(); target_col.assert_notetype(¬etype); diff --git a/rslib/src/import_export/package/colpkg/export.rs b/rslib/src/import_export/package/colpkg/export.rs index 04073532a..9d85c19aa 100644 --- a/rslib/src/import_export/package/colpkg/export.rs +++ b/rslib/src/import_export/package/colpkg/export.rs @@ -33,8 +33,8 @@ use crate::import_export::package::media::new_media_entry; use crate::import_export::package::media::MediaCopier; use crate::import_export::package::media::MediaIter; use crate::import_export::ExportProgress; -use crate::import_export::IncrementableProgress; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::storage::SchemaVersion; /// Enable multithreaded compression if over this size. For smaller files, @@ -48,10 +48,8 @@ impl Collection { out_path: impl AsRef, include_media: bool, legacy: bool, - progress_fn: impl 'static + FnMut(ExportProgress, bool) -> bool, ) -> Result<()> { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ExportProgress::File)?; + let mut progress = self.new_progress_handler(); let colpkg_name = out_path.as_ref(); let temp_colpkg = new_tempfile_in_parent_of(colpkg_name)?; let src_path = self.col_path.clone(); @@ -87,7 +85,7 @@ fn export_collection_file( media_dir: Option, legacy: bool, tr: &I18n, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { let meta = if legacy { Meta::new_legacy() @@ -112,6 +110,7 @@ pub(crate) fn export_colpkg_from_data( tr: &I18n, ) -> Result<()> { let col_size = col_data.len(); + let mut progress = ThrottlingProgressHandler::new(Default::default()); export_collection( Meta::new(), out_path, @@ -119,7 +118,7 @@ pub(crate) fn export_colpkg_from_data( col_size, MediaIter::empty(), tr, - &mut IncrementableProgress::new(|_, _| true), + &mut progress, ) } @@ -130,9 +129,8 @@ pub(crate) fn export_collection( col_size: usize, media: MediaIter, tr: &I18n, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { - progress.call(ExportProgress::File)?; let out_file = File::create(&out_path)?; let mut zip = ZipWriter::new(out_file); @@ -217,7 +215,7 @@ fn write_media( meta: &Meta, zip: &mut ZipWriter, media: MediaIter, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { let mut media_entries = vec![]; write_media_files(meta, zip, media, &mut media_entries, progress)?; @@ -261,7 +259,7 @@ fn write_media_files( zip: &mut ZipWriter, media: MediaIter, media_entries: &mut Vec, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result<()> { let mut copier = MediaCopier::new(meta.zstd_compressed()); let mut incrementor = progress.incrementor(ExportProgress::Media); diff --git a/rslib/src/import_export/package/colpkg/import.rs b/rslib/src/import_export/package/colpkg/import.rs index a028c86ec..b8316c032 100644 --- a/rslib/src/import_export/package/colpkg/import.rs +++ b/rslib/src/import_export/package/colpkg/import.rs @@ -24,19 +24,17 @@ use crate::import_export::package::media::SafeMediaEntry; use crate::import_export::package::Meta; use crate::import_export::ImportError; use crate::import_export::ImportProgress; -use crate::import_export::IncrementableProgress; use crate::media::MediaManager; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; pub fn import_colpkg( colpkg_path: &str, target_col_path: &str, target_media_folder: &Path, media_db: &Path, - progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + mut progress: ThrottlingProgressHandler, ) -> Result<()> { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ImportProgress::File)?; let col_path = PathBuf::from(target_col_path); let mut tempfile = new_tempfile_in_parent_of(&col_path)?; @@ -45,9 +43,9 @@ pub fn import_colpkg( let meta = Meta::from_archive(&mut archive)?; copy_collection(&mut archive, &mut tempfile, &meta)?; - progress.call(ImportProgress::File)?; + progress.set(ImportProgress::File)?; check_collection_and_mod_schema(tempfile.path())?; - progress.call(ImportProgress::File)?; + progress.set(ImportProgress::File)?; restore_media( &meta, @@ -82,7 +80,7 @@ fn check_collection_and_mod_schema(col_path: &Path) -> Result<()> { fn restore_media( meta: &Meta, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, archive: &mut ZipArchive, media_folder: &Path, media_db: &Path, @@ -164,7 +162,7 @@ struct MediaComparer<'a>(Option>>); impl<'a> MediaComparer<'a> { fn new( meta: &Meta, - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, media_manager: &'a MediaManager, ) -> Result { Ok(Self(if meta.media_list_is_hashmap() { diff --git a/rslib/src/import_export/package/colpkg/tests.rs b/rslib/src/import_export/package/colpkg/tests.rs index ebda63a7b..ccfe9adaf 100644 --- a/rslib/src/import_export/package/colpkg/tests.rs +++ b/rslib/src/import_export/package/colpkg/tests.rs @@ -40,7 +40,8 @@ fn roundtrip() -> Result<()> { // export to a file let col = collection_with_media(dir, name)?; let colpkg_name = dir.join(format!("{name}.colpkg")); - col.export_colpkg(&colpkg_name, true, legacy, |_, _| true)?; + let progress = col.new_progress_handler(); + col.export_colpkg(&colpkg_name, true, legacy)?; // import into a new collection let anki2_name = dir @@ -56,7 +57,7 @@ fn roundtrip() -> Result<()> { &anki2_name, &import_media_dir, &import_media_db, - |_, _| true, + progress, )?; // confirm collection imported @@ -89,8 +90,7 @@ fn normalization_check_on_export() -> Result<()> { // manually write a file in the wrong encoding. write_file(col.media_folder.join("ぱぱ.jpg"), "nfd encoding")?; assert_eq!( - col.export_colpkg(&colpkg_name, true, false, |_, _| true,) - .unwrap_err(), + col.export_colpkg(&colpkg_name, true, false,).unwrap_err(), AnkiError::MediaCheckRequired ); // file should have been cleaned up diff --git a/rslib/src/import_export/text/csv/export.rs b/rslib/src/import_export/text/csv/export.rs index 656698b6d..25f1d1014 100644 --- a/rslib/src/import_export/text/csv/export.rs +++ b/rslib/src/import_export/text/csv/export.rs @@ -15,7 +15,6 @@ use regex::Regex; use super::metadata::Delimiter; use crate::import_export::text::csv::metadata::DelimeterExt; use crate::import_export::ExportProgress; -use crate::import_export::IncrementableProgress; use crate::notetype::RenderCardOutput; use crate::prelude::*; use crate::search::SearchNode; @@ -32,10 +31,8 @@ impl Collection { path: &str, search: impl TryIntoSearch, with_html: bool, - progress_fn: impl 'static + FnMut(ExportProgress, bool) -> bool, ) -> Result { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ExportProgress::File)?; + let mut progress = self.new_progress_handler::(); let mut incrementor = progress.incrementor(ExportProgress::Cards); let mut writer = file_writer_with_header(path, with_html)?; @@ -52,13 +49,8 @@ impl Collection { Ok(cards.len()) } - pub fn export_note_csv( - &mut self, - mut request: ExportNoteCsvRequest, - progress_fn: impl 'static + FnMut(ExportProgress, bool) -> bool, - ) -> Result { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ExportProgress::File)?; + pub fn export_note_csv(&mut self, mut request: ExportNoteCsvRequest) -> Result { + let mut progress = self.new_progress_handler::(); let mut incrementor = progress.incrementor(ExportProgress::Notes); let guard = self.search_notes_into_table(Into::::into(&mut request))?; diff --git a/rslib/src/import_export/text/csv/import.rs b/rslib/src/import_export/text/csv/import.rs index 936163b1c..a74ded8c3 100644 --- a/rslib/src/import_export/text/csv/import.rs +++ b/rslib/src/import_export/text/csv/import.rs @@ -18,24 +18,19 @@ use crate::import_export::text::csv::metadata::Delimiter; use crate::import_export::text::ForeignData; use crate::import_export::text::ForeignNote; use crate::import_export::text::NameOrId; -use crate::import_export::ImportProgress; use crate::import_export::NoteLog; use crate::prelude::*; use crate::text::strip_utf8_bom; impl Collection { - pub fn import_csv( - &mut self, - path: &str, - metadata: CsvMetadata, - progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, - ) -> Result> { + pub fn import_csv(&mut self, path: &str, metadata: CsvMetadata) -> Result> { + let progress = self.new_progress_handler(); let file = open_file(path)?; let mut ctx = ColumnContext::new(&metadata)?; let notes = ctx.deserialize_csv(file, metadata.delimiter())?; let mut data = ForeignData::from(metadata); data.notes = notes; - data.import(self, progress_fn) + data.import(self, progress) } } diff --git a/rslib/src/import_export/text/import.rs b/rslib/src/import_export/text/import.rs index 116f25221..b9ac7fe7f 100644 --- a/rslib/src/import_export/text/import.rs +++ b/rslib/src/import_export/text/import.rs @@ -18,7 +18,6 @@ use crate::import_export::text::ForeignNotetype; use crate::import_export::text::ForeignTemplate; use crate::import_export::text::MatchScope; use crate::import_export::ImportProgress; -use crate::import_export::IncrementableProgress; use crate::import_export::NoteLog; use crate::notes::field_checksum; use crate::notes::normalize_field; @@ -26,16 +25,16 @@ use crate::notetype::CardGenContext; use crate::notetype::CardTemplate; use crate::notetype::NoteField; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::text::strip_html_preserving_media_filenames; impl ForeignData { pub fn import( self, col: &mut Collection, - progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, + mut progress: ThrottlingProgressHandler, ) -> Result> { - let mut progress = IncrementableProgress::new(progress_fn); - progress.call(ImportProgress::File)?; + progress.set(ImportProgress::File)?; col.transact(Op::Import, |col| { self.update_config(col)?; let mut ctx = Context::new(&self, col)?; @@ -229,7 +228,7 @@ impl<'a> Context<'a> { notes: Vec, global_tags: &[String], updated_tags: &[String], - progress: &mut IncrementableProgress, + progress: &mut ThrottlingProgressHandler, ) -> Result { let mut incrementor = progress.incrementor(ImportProgress::Notes); let mut log = new_note_log(self.dupe_resolution, notes.len() as u32); @@ -654,8 +653,10 @@ mod test { data.add_note(&["same", "old"]); data.dupe_resolution = DupeResolution::Duplicate; - data.clone().import(&mut col, |_, _| true).unwrap(); - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.clone().import(&mut col, progress).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); assert_eq!(col.storage.notes_table_len(), 2); } @@ -665,12 +666,13 @@ mod test { let mut data = ForeignData::with_defaults(); data.add_note(&["same", "old"]); data.dupe_resolution = DupeResolution::Preserve; - - data.clone().import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.clone().import(&mut col, progress).unwrap(); assert_eq!(col.storage.notes_table_len(), 1); data.notes[0].fields[1].replace("new".to_string()); - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); let notes = col.storage.get_all_notes(); assert_eq!(notes.len(), 1); assert_eq!(notes[0].fields()[1], "old"); @@ -682,12 +684,13 @@ mod test { let mut data = ForeignData::with_defaults(); data.add_note(&["same", "old"]); data.dupe_resolution = DupeResolution::Update; - - data.clone().import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.clone().import(&mut col, progress).unwrap(); assert_eq!(col.storage.notes_table_len(), 1); data.notes[0].fields[1].replace("new".to_string()); - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); assert_eq!(col.storage.get_all_notes()[0].fields()[1], "new"); } @@ -698,13 +701,14 @@ mod test { data.add_note(&["same", "unchanged"]); data.add_note(&["same", "unchanged"]); data.dupe_resolution = DupeResolution::Update; - - data.clone().import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.clone().import(&mut col, progress).unwrap(); assert_eq!(col.storage.notes_table_len(), 2); data.notes[0].fields[1] = None; data.notes[1].fields.pop(); - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); let notes = col.storage.get_all_notes(); assert_eq!(notes[0].fields(), &["same", "unchanged"]); assert_eq!(notes[0].fields(), &["same", "unchanged"]); @@ -719,13 +723,15 @@ mod test { let mut data = ForeignData::with_defaults(); data.dupe_resolution = DupeResolution::Update; data.add_note(&["神", "new"]); + let progress = col.new_progress_handler(); - data.clone().import(&mut col, |_, _| true).unwrap(); + data.clone().import(&mut col, progress).unwrap(); assert_eq!(col.storage.get_all_notes()[0].fields(), &["神", "new"]); col.set_config_bool(BoolKey::NormalizeNoteText, false, false) .unwrap(); - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); let notes = col.storage.get_all_notes(); assert_eq!(notes[0].fields(), &["神", "new"]); assert_eq!(notes[1].fields(), &["神", "new"]); @@ -738,8 +744,8 @@ mod test { data.add_note(&["foo"]); data.notes[0].tags.replace(vec![String::from("bar")]); data.global_tags = vec![String::from("baz")]; - - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); assert_eq!(col.storage.get_all_notes()[0].tags, ["bar", "baz"]); } @@ -750,8 +756,8 @@ mod test { data.add_note(&["foo"]); data.notes[0].tags.replace(vec![String::from("bar")]); data.global_tags = vec![String::from("baz")]; - - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); assert_eq!(col.storage.get_all_notes()[0].tags, ["bar", "baz"]); } @@ -769,8 +775,8 @@ mod test { let mut data = ForeignData::with_defaults(); data.match_scope = MatchScope::NotetypeAndDeck; data.add_note(&["foo", "new"]); - - data.import(&mut col, |_, _| true).unwrap(); + let progress = col.new_progress_handler(); + data.import(&mut col, progress).unwrap(); let notes = col.storage.get_all_notes(); // same deck, should be updated assert_eq!(notes[0].fields()[1], "new"); diff --git a/rslib/src/import_export/text/json.rs b/rslib/src/import_export/text/json.rs index ca7257037..1abf49f5f 100644 --- a/rslib/src/import_export/text/json.rs +++ b/rslib/src/import_export/text/json.rs @@ -4,29 +4,20 @@ use anki_io::read_file; use crate::import_export::text::ForeignData; -use crate::import_export::ImportProgress; use crate::import_export::NoteLog; use crate::prelude::*; impl Collection { - pub fn import_json_file( - &mut self, - path: &str, - mut progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, - ) -> Result> { - progress_fn(ImportProgress::Gathering, false); + pub fn import_json_file(&mut self, path: &str) -> Result> { + let progress = self.new_progress_handler(); let slice = read_file(path)?; let data: ForeignData = serde_json::from_slice(&slice)?; - data.import(self, progress_fn) + data.import(self, progress) } - pub fn import_json_string( - &mut self, - json: &str, - mut progress_fn: impl 'static + FnMut(ImportProgress, bool) -> bool, - ) -> Result> { - progress_fn(ImportProgress::Gathering, false); + pub fn import_json_string(&mut self, json: &str) -> Result> { + let progress = self.new_progress_handler(); let data: ForeignData = serde_json::from_str(json)?; - data.import(self, progress_fn) + data.import(self, progress) } } diff --git a/rslib/src/lib.rs b/rslib/src/lib.rs index 2360aa8f0..d2f4c2b6e 100644 --- a/rslib/src/lib.rs +++ b/rslib/src/lib.rs @@ -28,6 +28,7 @@ pub mod notetype; pub mod ops; mod preferences; pub mod prelude; +mod progress; pub mod revlog; pub mod scheduler; pub mod search; diff --git a/rslib/src/media/check.rs b/rslib/src/media/check.rs index 2e90fbb7c..85d3bed70 100644 --- a/rslib/src/media/check.rs +++ b/rslib/src/media/check.rs @@ -19,6 +19,8 @@ use crate::media::files::normalize_nfc_filename; use crate::media::files::trash_folder; use crate::media::MediaManager; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; +use crate::sync::media::progress::MediaCheckProgress; use crate::sync::media::MAX_INDIVIDUAL_MEDIA_FILE_SIZE; use crate::text::extract_media_refs; use crate::text::normalize_to_nfc; @@ -45,31 +47,25 @@ struct MediaFolderCheck { oversize: Vec, } -pub struct MediaChecker<'a, 'b, P> -where - P: FnMut(usize) -> bool, -{ - ctx: &'a mut Collection, - mgr: &'b MediaManager, - progress_cb: P, - checked: usize, +impl Collection { + pub fn media_checker(&mut self) -> Result> { + MediaChecker::new(self) + } } -impl

MediaChecker<'_, '_, P> -where - P: FnMut(usize) -> bool, -{ - pub(crate) fn new<'a, 'b>( - ctx: &'a mut Collection, - mgr: &'b MediaManager, - progress_cb: P, - ) -> MediaChecker<'a, 'b, P> { - MediaChecker { - ctx, - mgr, - progress_cb, - checked: 0, - } +pub struct MediaChecker<'a> { + col: &'a mut Collection, + media: MediaManager, + progress: ThrottlingProgressHandler, +} + +impl MediaChecker<'_> { + pub(crate) fn new(col: &mut Collection) -> Result> { + Ok(MediaChecker { + media: col.media()?, + progress: col.new_progress_handler(), + col, + }) } pub fn check(&mut self) -> Result { @@ -91,7 +87,7 @@ where pub fn summarize_output(&self, output: &mut MediaCheckOutput) -> String { let mut buf = String::new(); - let tr = &self.ctx.tr; + let tr = &self.col.tr; // top summary area if output.trash_count > 0 { @@ -179,6 +175,10 @@ where buf } + fn increment_progress(&mut self) -> Result<()> { + self.progress.increment(|p| &mut p.checked) + } + /// Check all the files in the media folder. /// /// - Renames files with invalid names @@ -186,13 +186,11 @@ where /// - Gathers a list of all files fn check_media_folder(&mut self) -> Result { let mut out = MediaFolderCheck::default(); - for dentry in self.mgr.media_folder.read_dir()? { + + for dentry in self.media.media_folder.read_dir()? { let dentry = dentry?; - self.checked += 1; - if self.checked % 10 == 0 { - self.fire_progress_cb()?; - } + self.increment_progress()?; // if the filename is not valid unicode, skip it let fname_os = dentry.file_name(); @@ -220,7 +218,7 @@ where if let Some(norm_name) = filename_if_normalized(disk_fname) { out.files.push(norm_name.into_owned()); } else { - match data_for_file(&self.mgr.media_folder, disk_fname)? { + match data_for_file(&self.media.media_folder, disk_fname)? { Some(data) => { let norm_name = self.normalize_file(disk_fname, data)?; out.renamed @@ -242,38 +240,27 @@ where /// Write file data to normalized location, moving old file to trash. fn normalize_file<'a>(&mut self, disk_fname: &'a str, data: Vec) -> Result> { // add a copy of the file using the correct name - let fname = self.mgr.add_file(disk_fname, &data)?; + let fname = self.media.add_file(disk_fname, &data)?; debug!(from = disk_fname, to = &fname.as_ref(), "renamed"); assert_ne!(fname.as_ref(), disk_fname); // remove the original file - let path = &self.mgr.media_folder.join(disk_fname); + let path = &self.media.media_folder.join(disk_fname); fs::remove_file(path)?; Ok(fname) } - fn fire_progress_cb(&mut self) -> Result<()> { - if (self.progress_cb)(self.checked) { - Ok(()) - } else { - Err(AnkiError::Interrupted) - } - } - /// Returns the count and total size of the files in the trash folder fn files_in_trash(&mut self) -> Result<(u64, u64)> { - let trash = trash_folder(&self.mgr.media_folder)?; + let trash = trash_folder(&self.media.media_folder)?; let mut total_files = 0; let mut total_bytes = 0; for dentry in trash.read_dir()? { let dentry = dentry?; - self.checked += 1; - if self.checked % 10 == 0 { - self.fire_progress_cb()?; - } + self.increment_progress()?; if dentry.file_name() == ".DS_Store" { continue; @@ -289,15 +276,12 @@ where } pub fn empty_trash(&mut self) -> Result<()> { - let trash = trash_folder(&self.mgr.media_folder)?; + let trash = trash_folder(&self.media.media_folder)?; for dentry in trash.read_dir()? { let dentry = dentry?; - self.checked += 1; - if self.checked % 10 == 0 { - self.fire_progress_cb()?; - } + self.increment_progress()?; fs::remove_file(dentry.path())?; } @@ -306,17 +290,14 @@ where } pub fn restore_trash(&mut self) -> Result<()> { - let trash = trash_folder(&self.mgr.media_folder)?; + let trash = trash_folder(&self.media.media_folder)?; for dentry in trash.read_dir()? { let dentry = dentry?; - self.checked += 1; - if self.checked % 10 == 0 { - self.fire_progress_cb()?; - } + self.increment_progress()?; - let orig_path = self.mgr.media_folder.join(dentry.file_name()); + let orig_path = self.media.media_folder.join(dentry.file_name()); // if the original filename doesn't exist, we can just rename if let Err(e) = fs::metadata(&orig_path) { if e.kind() == io::ErrorKind::NotFound { @@ -329,7 +310,7 @@ where let fname_os = dentry.file_name(); let fname = fname_os.to_string_lossy(); if let Some(data) = data_for_file(&trash, fname.as_ref())? { - let _new_fname = self.mgr.add_file(fname.as_ref(), &data)?; + let _new_fname = self.media.add_file(fname.as_ref(), &data)?; } else { debug!(?fname, "file disappeared while restoring trash"); } @@ -346,17 +327,14 @@ where renamed: &HashMap, ) -> Result>> { let mut referenced_files = HashMap::new(); - let notetypes = self.ctx.get_all_notetypes()?; + let notetypes = self.col.get_all_notetypes()?; let mut collection_modified = false; - let nids = self.ctx.search_notes_unordered("")?; - let usn = self.ctx.usn()?; + let nids = self.col.search_notes_unordered("")?; + let usn = self.col.usn()?; for nid in nids { - self.checked += 1; - if self.checked % 10 == 0 { - self.fire_progress_cb()?; - } - let mut note = self.ctx.storage.get_note(nid)?.unwrap(); + self.increment_progress()?; + let mut note = self.col.storage.get_note(nid)?.unwrap(); let nt = notetypes.get(¬e.notetype_id).ok_or_else(|| { AnkiError::db_error("missing note type", DbErrorKind::MissingEntity) })?; @@ -366,12 +344,16 @@ where .or_insert_with(Vec::new) .push(nid) }; - if fix_and_extract_media_refs(&mut note, &mut tracker, renamed, &self.mgr.media_folder)? - { + if fix_and_extract_media_refs( + &mut note, + &mut tracker, + renamed, + &self.media.media_folder, + )? { // note was modified, needs saving note.prepare_for_update(nt, false)?; note.set_modified(usn); - self.ctx.storage.update_note(¬e)?; + self.col.storage.update_note(¬e)?; collection_modified = true; } @@ -557,10 +539,8 @@ pub(crate) mod test { write_file(mgr.media_folder.join("_under.jpg"), "foo")?; write_file(mgr.media_folder.join("unused.jpg"), "foo")?; - let progress = |_n| true; - let (output, report) = { - let mut checker = MediaChecker::new(&mut col, &mgr, progress); + let mut checker = col.media_checker()?; let output = checker.check()?; let summary = checker.summarize_output(&mut output.clone()); (output, summary) @@ -628,9 +608,7 @@ Unused: unused.jpg let trash_folder = trash_folder(&mgr.media_folder)?; write_file(trash_folder.join("test.jpg"), "test")?; - let progress = |_n| true; - - let mut checker = MediaChecker::new(&mut col, &mgr, progress); + let mut checker = col.media_checker()?; checker.restore_trash()?; // file should have been moved to media folder @@ -644,7 +622,7 @@ Unused: unused.jpg // are equal write_file(trash_folder.join("test.jpg"), "test")?; - let mut checker = MediaChecker::new(&mut col, &mgr, progress); + let mut checker = col.media_checker()?; checker.restore_trash()?; assert_eq!(files_in_dir(&trash_folder), Vec::::new()); @@ -656,7 +634,7 @@ Unused: unused.jpg // but rename if required write_file(trash_folder.join("test.jpg"), "test2")?; - let mut checker = MediaChecker::new(&mut col, &mgr, progress); + let mut checker = col.media_checker()?; checker.restore_trash()?; assert_eq!(files_in_dir(&trash_folder), Vec::::new()); @@ -677,10 +655,8 @@ Unused: unused.jpg write_file(mgr.media_folder.join("ぱぱ.jpg"), "nfd encoding")?; - let progress = |_n| true; - let mut output = { - let mut checker = MediaChecker::new(&mut col, &mgr, progress); + let mut checker = col.media_checker()?; checker.check() }?; diff --git a/rslib/src/media/mod.rs b/rslib/src/media/mod.rs index effe56122..39b420c32 100644 --- a/rslib/src/media/mod.rs +++ b/rslib/src/media/mod.rs @@ -16,6 +16,7 @@ use crate::media::files::mtime_as_i64; use crate::media::files::remove_files; use crate::media::files::sha1_of_data; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::sync::http_client::HttpSyncClient; use crate::sync::login::SyncAuth; use crate::sync::media::database::client::changetracker::ChangeTracker; @@ -139,10 +140,11 @@ impl MediaManager { } /// Sync media. - pub async fn sync_media(self, progress: F, auth: SyncAuth) -> Result<()> - where - F: FnMut(MediaSyncProgress) -> bool, - { + pub async fn sync_media( + self, + progress: ThrottlingProgressHandler, + auth: SyncAuth, + ) -> Result<()> { let client = HttpSyncClient::new(auth); let mut syncer = MediaSyncer::new(self, progress, client)?; syncer.sync().await diff --git a/rslib/src/backend/progress.rs b/rslib/src/progress.rs similarity index 52% rename from rslib/src/backend/progress.rs rename to rslib/src/progress.rs index 7d504ad8a..3224bcd8b 100644 --- a/rslib/src/backend/progress.rs +++ b/rslib/src/progress.rs @@ -1,54 +1,131 @@ // Copyright: Ankitects Pty Ltd and contributors // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +use std::marker::PhantomData; use std::sync::Arc; use std::sync::Mutex; use anki_i18n::I18n; use futures::future::AbortHandle; -use super::Backend; use crate::dbcheck::DatabaseCheckProgress; +use crate::error::AnkiError; +use crate::error::Result; use crate::import_export::ExportProgress; use crate::import_export::ImportProgress; +use crate::prelude::Collection; use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::progress::FullSyncProgress; use crate::sync::collection::progress::SyncStage; +use crate::sync::media::progress::MediaCheckProgress; use crate::sync::media::progress::MediaSyncProgress; -pub(super) struct ThrottlingProgressHandler { - pub state: Arc>, - pub last_update: coarsetime::Instant, +/// Stores progress state that can be updated cheaply, and will update a +/// Mutex-protected copy that other threads can check, if more than 0.1 +/// secs has elapsed since the previous update. +/// If another thread has set the `want_abort` flag on the shared state, +/// then the next non-throttled update will fail with [AnkiError::Interrupted]. +/// Automatically updates the shared state on creation, with the default +/// value for the type. +#[derive(Debug, Default)] +pub struct ThrottlingProgressHandler + Default> { + pub(crate) state: P, + shared_state: Arc>, + last_shared_update: coarsetime::Instant, } -impl ThrottlingProgressHandler { - /// Returns true if should continue. - pub(super) fn update(&mut self, progress: impl Into, throttle: bool) -> bool { - let now = coarsetime::Instant::now(); - if throttle && now.duration_since(self.last_update).as_f64() < 0.1 { - return true; +impl + Default + Clone> ThrottlingProgressHandler

{ + pub(crate) fn new(shared_state: Arc>) -> Self { + let initial = P::default(); + { + let mut guard = shared_state.lock().unwrap(); + guard.last_progress = Some(initial.clone().into()); + guard.want_abort = false; } - self.last_update = now; - let mut guard = self.state.lock().unwrap(); - guard.last_progress.replace(progress.into()); - let want_abort = guard.want_abort; - guard.want_abort = false; - !want_abort + Self { + shared_state, + state: initial, + ..Default::default() + } + } + + /// Overwrite the currently-stored state. This does not throttle, and should + /// be used when you want to ensure the UI state gets updated, and + /// ensure that the abort flag is checked between expensive steps. + pub(crate) fn set(&mut self, progress: P) -> Result<()> { + self.update(false, |state| *state = progress) + } + + /// Mutate the currently-stored state, and maybe update shared state. + pub(crate) fn update(&mut self, throttle: bool, mutator: impl FnOnce(&mut P)) -> Result<()> { + mutator(&mut self.state); + + let now = coarsetime::Instant::now(); + if throttle && now.duration_since(self.last_shared_update).as_f64() < 0.1 { + return Ok(()); + } + self.last_shared_update = now; + + let mut guard = self.shared_state.lock().unwrap(); + guard.last_progress.replace(self.state.clone().into()); + + if std::mem::take(&mut guard.want_abort) { + Err(AnkiError::Interrupted) + } else { + Ok(()) + } + } + + /// Check the abort flag, and trigger a UI update if it was throttled. + pub(crate) fn check_cancelled(&mut self) -> Result<()> { + self.set(self.state.clone()) + } + + /// An alternative to incrementor() below, that can be used across function + /// calls easily, as it continues from the previous state. + pub(crate) fn increment(&mut self, accessor: impl Fn(&mut P) -> &mut usize) -> Result<()> { + let field = accessor(&mut self.state); + *field += 1; + if *field % 17 == 0 { + self.update(true, |_| ())?; + } + Ok(()) + } + + /// Returns an [Incrementor] with an `increment()` function for use in + /// loops. + pub(crate) fn incrementor<'inc, 'progress: 'inc, 'map: 'inc>( + &'progress mut self, + mut count_map: impl 'map + FnMut(usize) -> P, + ) -> Incrementor<'inc, impl FnMut(usize) -> Result<()> + 'inc> { + Incrementor::new(move |u| self.update(true, |p| *p = count_map(u))) + } + + /// Stopgap for returning a progress fn compliant with the media code. + pub(crate) fn media_db_fn( + &mut self, + count_map: impl 'static + Fn(usize) -> P, + ) -> Result bool + '_> + where + P: Into, + { + Ok(move |count| self.update(true, |p| *p = count_map(count)).is_ok()) } } -pub(super) struct ProgressState { +#[derive(Default, Debug)] +pub struct ProgressState { pub want_abort: bool, pub last_progress: Option, } // fixme: this should support multiple abort handles. -pub(super) type AbortHandleSlot = Arc>>; +pub(crate) type AbortHandleSlot = Arc>>; -#[derive(Clone, Copy)] -pub(super) enum Progress { +#[derive(Clone, Copy, Debug)] +pub enum Progress { MediaSync(MediaSyncProgress), - MediaCheck(u32), + MediaCheck(MediaCheckProgress), FullSync(FullSyncProgress), NormalSync(NormalSyncProgress), DatabaseCheck(DatabaseCheckProgress), @@ -56,7 +133,7 @@ pub(super) enum Progress { Export(ExportProgress), } -pub(super) fn progress_to_proto( +pub(crate) fn progress_to_proto( progress: Option, tr: &I18n, ) -> anki_proto::collection::Progress { @@ -66,7 +143,7 @@ pub(super) fn progress_to_proto( anki_proto::collection::progress::Value::MediaSync(media_sync_progress(p, tr)) } Progress::MediaCheck(n) => anki_proto::collection::progress::Value::MediaCheck( - tr.media_check_checked(n).into(), + tr.media_check_checked(n.checked).into(), ), Progress::FullSync(p) => anki_proto::collection::progress::Value::FullSync( anki_proto::collection::progress::FullSync { @@ -113,8 +190,8 @@ pub(super) fn progress_to_proto( anki_proto::collection::progress::Value::DatabaseCheck( anki_proto::collection::progress::DatabaseCheck { stage, - stage_total, - stage_current, + stage_total: stage_total as u32, + stage_current: stage_current as u32, }, ) } @@ -175,22 +252,72 @@ impl From for Progress { } } +impl From for Progress { + fn from(p: MediaCheckProgress) -> Self { + Progress::MediaCheck(p) + } +} + impl From for Progress { fn from(p: NormalSyncProgress) -> Self { Progress::NormalSync(p) } } -impl Backend { - pub(super) fn new_progress_handler(&self) -> ThrottlingProgressHandler { - { - let mut guard = self.progress_state.lock().unwrap(); - guard.want_abort = false; - guard.last_progress = None; - } - ThrottlingProgressHandler { - state: Arc::clone(&self.progress_state), - last_update: coarsetime::Instant::now(), - } +impl From for Progress { + fn from(p: DatabaseCheckProgress) -> Self { + Progress::DatabaseCheck(p) + } +} + +impl From for Progress { + fn from(p: ImportProgress) -> Self { + Progress::Import(p) + } +} + +impl From for Progress { + fn from(p: ExportProgress) -> Self { + Progress::Export(p) + } +} + +impl Collection { + pub fn new_progress_handler + Default + Clone>( + &self, + ) -> ThrottlingProgressHandler

{ + ThrottlingProgressHandler::new(self.state.progress.clone()) + } +} + +pub(crate) struct Incrementor<'f, F: 'f + FnMut(usize) -> Result<()>> { + update_fn: F, + count: usize, + update_interval: usize, + _phantom: PhantomData<&'f ()>, +} + +impl<'f, F: 'f + FnMut(usize) -> Result<()>> Incrementor<'f, F> { + fn new(update_fn: F) -> Self { + Self { + update_fn, + count: 0, + update_interval: 17, + _phantom: PhantomData, + } + } + + /// Increments the progress counter, periodically triggering an update. + /// Returns [AnkiError::Interrupted] if the operation should be cancelled. + pub(crate) fn increment(&mut self) -> Result<()> { + self.count += 1; + if self.count % self.update_interval != 0 { + return Ok(()); + } + (self.update_fn)(self.count) + } + + pub(crate) fn count(&self) -> usize { + self.count } } diff --git a/rslib/src/sync/collection/changes.rs b/rslib/src/sync/collection/changes.rs index 474259b75..af3789b1c 100644 --- a/rslib/src/sync/collection/changes.rs +++ b/rslib/src/sync/collection/changes.rs @@ -19,7 +19,6 @@ use crate::error::SyncErrorKind; use crate::notetype::NotetypeSchema11; use crate::prelude::*; use crate::sync::collection::normal::ClientSyncState; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::protocol::SyncProtocol; use crate::sync::collection::start::ServerSyncState; @@ -52,10 +51,7 @@ pub struct DecksAndConfig { config: Vec, } -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ +impl NormalSyncer<'_> { // This was assumed to a cheap operation when originally written - it didn't // anticipate the large deck trees and note types some users would create. // They should be chunked in the future, like other objects. Syncing tags @@ -79,16 +75,18 @@ where "sending" ); - self.progress.local_update += local.notetypes.len() - + local.decks_and_config.decks.len() - + local.decks_and_config.config.len() - + local.tags.len(); + self.progress.update(false, |p| { + p.local_update += local.notetypes.len() + + local.decks_and_config.decks.len() + + local.decks_and_config.config.len() + + local.tags.len(); + })?; let remote = self .server .apply_changes(ApplyChangesRequest { changes: local }.try_into_sync_request()?) .await? .json()?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; debug!( notetypes = remote.notetypes.len(), @@ -98,13 +96,15 @@ where "received" ); - self.progress.remote_update += remote.notetypes.len() - + remote.decks_and_config.decks.len() - + remote.decks_and_config.config.len() - + remote.tags.len(); + self.progress.update(false, |p| { + p.remote_update += remote.notetypes.len() + + remote.decks_and_config.decks.len() + + remote.decks_and_config.config.len() + + remote.tags.len(); + })?; self.col.apply_changes(remote, state.server_usn)?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; Ok(()) } } diff --git a/rslib/src/sync/collection/chunks.rs b/rslib/src/sync/collection/chunks.rs index f786f3bfc..fcea323d3 100644 --- a/rslib/src/sync/collection/chunks.rs +++ b/rslib/src/sync/collection/chunks.rs @@ -17,7 +17,6 @@ use crate::serde::deserialize_int_from_number; use crate::storage::card::data::card_data_string; use crate::storage::card::data::CardData; use crate::sync::collection::normal::ClientSyncState; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::protocol::EmptyInput; use crate::sync::collection::protocol::SyncProtocol; @@ -87,10 +86,7 @@ pub struct CardEntry { pub data: String, } -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ +impl NormalSyncer<'_> { pub(in crate::sync) async fn process_chunks_from_server( &mut self, state: &ClientSyncState, @@ -106,13 +102,14 @@ where "received" ); - self.progress.remote_update += - chunk.cards.len() + chunk.notes.len() + chunk.revlog.len(); + self.progress.update(false, |p| { + p.remote_update += chunk.cards.len() + chunk.notes.len() + chunk.revlog.len() + })?; let done = chunk.done; self.col.apply_chunk(chunk, state.pending_usn)?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; if done { return Ok(()); @@ -138,14 +135,15 @@ where "sending" ); - self.progress.local_update += - chunk.cards.len() + chunk.notes.len() + chunk.revlog.len(); + self.progress.update(false, |p| { + p.local_update += chunk.cards.len() + chunk.notes.len() + chunk.revlog.len() + })?; self.server .apply_chunk(ApplyChunkRequest { chunk }.try_into_sync_request()?) .await?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; if done { return Ok(()); diff --git a/rslib/src/sync/collection/download.rs b/rslib/src/sync/collection/download.rs index 6aa8bdacc..66c5400e4 100644 --- a/rslib/src/sync/collection/download.rs +++ b/rslib/src/sync/collection/download.rs @@ -9,9 +9,7 @@ use anki_io::write_file; use crate::collection::CollectionBuilder; use crate::prelude::*; use crate::storage::SchemaVersion; -use crate::sync::collection::progress::FullSyncProgressFn; use crate::sync::collection::protocol::EmptyInput; -use crate::sync::collection::protocol::SyncProtocol; use crate::sync::error::HttpResult; use crate::sync::error::OrHttpErr; use crate::sync::http_client::HttpSyncClient; @@ -19,21 +17,21 @@ use crate::sync::login::SyncAuth; impl Collection { /// Download collection from AnkiWeb. Caller must re-open afterwards. - pub async fn full_download( - self, - auth: SyncAuth, - progress_fn: FullSyncProgressFn, - ) -> Result<()> { - let mut server = HttpSyncClient::new(auth); - server.set_full_sync_progress_fn(Some(progress_fn)); - self.full_download_with_server(server).await + pub async fn full_download(self, auth: SyncAuth) -> Result<()> { + self.full_download_with_server(HttpSyncClient::new(auth)) + .await } - pub(crate) async fn full_download_with_server(self, server: HttpSyncClient) -> Result<()> { + // pub for tests + pub(super) async fn full_download_with_server(self, server: HttpSyncClient) -> Result<()> { let col_path = self.col_path.clone(); let _col_folder = col_path.parent().or_invalid("couldn't get col_folder")?; + let progress = self.new_progress_handler(); self.close(None)?; - let out_data = server.download(EmptyInput::request()).await?.data; + let out_data = server + .download_with_progress(EmptyInput::request(), progress) + .await? + .data; // check file ok let temp_file = new_tempfile_in_parent_of(&col_path)?; write_file(temp_file.path(), out_data)?; diff --git a/rslib/src/sync/collection/finish.rs b/rslib/src/sync/collection/finish.rs index b06f03b44..b69fbcad7 100644 --- a/rslib/src/sync/collection/finish.rs +++ b/rslib/src/sync/collection/finish.rs @@ -3,15 +3,11 @@ use crate::prelude::*; use crate::sync::collection::normal::ClientSyncState; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::protocol::EmptyInput; use crate::sync::collection::protocol::SyncProtocol; -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ +impl NormalSyncer<'_> { pub(in crate::sync) async fn finalize(&mut self, state: &ClientSyncState) -> Result<()> { let new_server_mtime = self.server.finish(EmptyInput::request()).await?.json()?; self.col.finalize_sync(state, new_server_mtime) diff --git a/rslib/src/sync/collection/normal.rs b/rslib/src/sync/collection/normal.rs index eef6fae55..6b8378b22 100644 --- a/rslib/src/sync/collection/normal.rs +++ b/rslib/src/sync/collection/normal.rs @@ -9,6 +9,7 @@ use crate::error::AnkiError; use crate::error::SyncError; use crate::error::SyncErrorKind; use crate::prelude::Usn; +use crate::progress::ThrottlingProgressHandler; use crate::sync::collection::progress::SyncStage; use crate::sync::collection::protocol::EmptyInput; use crate::sync::collection::protocol::SyncProtocol; @@ -16,11 +17,10 @@ use crate::sync::collection::status::online_sync_status_check; use crate::sync::http_client::HttpSyncClient; use crate::sync::login::SyncAuth; -pub struct NormalSyncer<'a, F> { +pub struct NormalSyncer<'a> { pub(in crate::sync) col: &'a mut Collection, pub(in crate::sync) server: HttpSyncClient, - pub(in crate::sync) progress: NormalSyncProgress, - pub(in crate::sync) progress_fn: F, + pub(in crate::sync) progress: ThrottlingProgressHandler, } #[derive(Default, Debug, Clone, Copy)] @@ -54,29 +54,17 @@ pub struct ClientSyncState { pub(in crate::sync) pending_usn: Usn, } -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ - pub fn new(col: &mut Collection, server: HttpSyncClient, progress_fn: F) -> NormalSyncer<'_, F> - where - F: FnMut(NormalSyncProgress, bool), - { +impl NormalSyncer<'_> { + pub fn new(col: &mut Collection, server: HttpSyncClient) -> NormalSyncer<'_> { NormalSyncer { + progress: col.new_progress_handler(), col, server, - progress: NormalSyncProgress::default(), - progress_fn, } } - pub(in crate::sync) fn fire_progress_cb(&mut self, throttle: bool) { - (self.progress_fn)(self.progress, throttle) - } - pub async fn sync(&mut self) -> error::Result { debug!("fetching meta..."); - self.fire_progress_cb(false); let local = self.col.sync_meta()?; let state = online_sync_status_check(local, &mut self.server).await?; debug!(?state, "fetched"); @@ -120,8 +108,8 @@ where /// Sync. Caller must have created a transaction, and should call /// abort on failure. async fn normal_sync_inner(&mut self, mut state: ClientSyncState) -> error::Result { - self.progress.stage = SyncStage::Syncing; - self.fire_progress_cb(false); + self.progress + .update(false, |p| p.stage = SyncStage::Syncing)?; debug!("start"); self.start_and_process_deletions(&state).await?; @@ -132,8 +120,8 @@ where debug!("begin stream to server"); self.send_chunks_to_server(&state).await?; - self.progress.stage = SyncStage::Finalizing; - self.fire_progress_cb(false); + self.progress + .update(false, |p| p.stage = SyncStage::Finalizing)?; debug!("sanity check"); self.sanity_check().await?; @@ -164,15 +152,8 @@ impl From for SyncOutput { } impl Collection { - pub async fn normal_sync( - &mut self, - auth: SyncAuth, - progress_fn: F, - ) -> error::Result - where - F: FnMut(NormalSyncProgress, bool), - { - NormalSyncer::new(self, HttpSyncClient::new(auth), progress_fn) + pub async fn normal_sync(&mut self, auth: SyncAuth) -> error::Result { + NormalSyncer::new(self, HttpSyncClient::new(auth)) .sync() .await } diff --git a/rslib/src/sync/collection/progress.rs b/rslib/src/sync/collection/progress.rs index e136868b4..419d80fcf 100644 --- a/rslib/src/sync/collection/progress.rs +++ b/rslib/src/sync/collection/progress.rs @@ -27,5 +27,3 @@ pub async fn sync_abort(auth: SyncAuth) -> error::Result<()> { .await? .json() } - -pub type FullSyncProgressFn = Box; diff --git a/rslib/src/sync/collection/sanity.rs b/rslib/src/sync/collection/sanity.rs index 5fa991f77..fb7f25255 100644 --- a/rslib/src/sync/collection/sanity.rs +++ b/rslib/src/sync/collection/sanity.rs @@ -10,7 +10,6 @@ use tracing::info; use crate::error::SyncErrorKind; use crate::prelude::*; use crate::serde::default_on_invalid; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::protocol::SyncProtocol; use crate::sync::request::IntoSyncRequest; @@ -51,10 +50,7 @@ pub struct SanityCheckDueCounts { pub review: u32, } -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ +impl NormalSyncer<'_> { /// Caller should force full sync after rolling back. pub(in crate::sync) async fn sanity_check(&mut self) -> Result<()> { let local_counts = self.col.storage.sanity_check_info()?; diff --git a/rslib/src/sync/collection/start.rs b/rslib/src/sync/collection/start.rs index f15b0c369..4523b3adb 100644 --- a/rslib/src/sync/collection/start.rs +++ b/rslib/src/sync/collection/start.rs @@ -11,15 +11,11 @@ use crate::sync::collection::chunks::ChunkableIds; use crate::sync::collection::graves::ApplyGravesRequest; use crate::sync::collection::graves::Graves; use crate::sync::collection::normal::ClientSyncState; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::protocol::SyncProtocol; use crate::sync::request::IntoSyncRequest; -impl NormalSyncer<'_, F> -where - F: FnMut(NormalSyncProgress, bool), -{ +impl NormalSyncer<'_> { pub(in crate::sync) async fn start_and_process_deletions( &mut self, state: &ClientSyncState, @@ -58,16 +54,20 @@ where while let Some(chunk) = local.take_chunk() { debug!("sending graves chunk"); - self.progress.local_remove += chunk.cards.len() + chunk.notes.len() + chunk.decks.len(); + self.progress.update(false, |p| { + p.local_remove += chunk.cards.len() + chunk.notes.len() + chunk.decks.len() + })?; self.server .apply_graves(ApplyGravesRequest { chunk }.try_into_sync_request()?) .await?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; } - self.progress.remote_remove = remote.cards.len() + remote.notes.len() + remote.decks.len(); + self.progress.update(false, |p| { + p.remote_remove = remote.cards.len() + remote.notes.len() + remote.decks.len() + })?; self.col.apply_graves(remote, state.server_usn)?; - self.fire_progress_cb(true); + self.progress.check_cancelled()?; debug!("applied server graves"); Ok(()) diff --git a/rslib/src/sync/collection/tests.rs b/rslib/src/sync/collection/tests.rs index a2e416f09..4e30f390f 100644 --- a/rslib/src/sync/collection/tests.rs +++ b/rslib/src/sync/collection/tests.rs @@ -34,11 +34,9 @@ use crate::revlog::RevlogEntry; use crate::search::SortMode; use crate::sync::collection::graves::ApplyGravesRequest; use crate::sync::collection::meta::MetaRequest; -use crate::sync::collection::normal::NormalSyncProgress; use crate::sync::collection::normal::NormalSyncer; use crate::sync::collection::normal::SyncActionRequired; use crate::sync::collection::normal::SyncOutput; -use crate::sync::collection::progress::FullSyncProgress; use crate::sync::collection::protocol::EmptyInput; use crate::sync::collection::protocol::SyncProtocol; use crate::sync::collection::start::StartRequest; @@ -111,10 +109,6 @@ fn unwrap_sync_err_kind(err: AnkiError) -> SyncErrorKind { kind } -fn norm_progress(_: NormalSyncProgress, _: bool) {} - -fn full_progress(_: FullSyncProgress, _: bool) {} - #[tokio::test] async fn host_key() -> Result<()> { with_active_server(|mut client| async move { @@ -209,7 +203,7 @@ async fn aborting_is_idempotent() -> Result<()> { #[tokio::test] async fn new_syncs_cancel_old_ones() -> Result<()> { with_active_server(|mut client| async move { - let ctx = SyncTestContext::new(client.partial_clone()); + let ctx = SyncTestContext::new(client.clone()); // start a sync let req = StartRequest { @@ -296,7 +290,7 @@ async fn sanity_check_should_roll_back_and_force_full_sync() -> Result<()> { .execute("update decks set usn=0 where id=?", [deck.id])?; // the sync should fail - let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress) + let err = NormalSyncer::new(&mut col1, ctx.cloned_client()) .sync() .await .unwrap_err(); @@ -349,7 +343,7 @@ async fn sync_errors_should_prompt_db_check() -> Result<()> { col1.storage.db.execute("update notetypes set usn=0", [])?; // the sync should fail - let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress) + let err = NormalSyncer::new(&mut col1, ctx.cloned_client()) .sync() .await .unwrap_err(); @@ -362,7 +356,7 @@ async fn sync_errors_should_prompt_db_check() -> Result<()> { assert_eq!(out.required, SyncActionRequired::NoChanges); // and the client should be able to sync again without a forced one-way sync - let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress) + let err = NormalSyncer::new(&mut col1, ctx.cloned_client()) .sync() .await .unwrap_err(); @@ -417,9 +411,7 @@ async fn string_grave_ids_are_handled() -> Result<()> { #[tokio::test] async fn invalid_uploads_should_be_handled() -> Result<()> { with_active_server(|client| async move { - let mut ctx = SyncTestContext::new(client); - ctx.client - .set_full_sync_progress_fn(Some(Box::new(full_progress))); + let ctx = SyncTestContext::new(client); let res = ctx .client .upload(b"fake data".to_vec().try_into_sync_request()?) @@ -494,7 +486,7 @@ impl SyncTestContext { } async fn normal_sync(&self, col: &mut Collection) -> SyncOutput { - NormalSyncer::new(col, self.cloned_client(), norm_progress) + NormalSyncer::new(col, self.cloned_client()) .sync() .await .unwrap() @@ -513,9 +505,7 @@ impl SyncTestContext { } fn cloned_client(&self) -> HttpSyncClient { - let mut client = self.client.partial_clone(); - client.set_full_sync_progress_fn(Some(Box::new(full_progress))); - client + self.client.clone() } } diff --git a/rslib/src/sync/collection/upload.rs b/rslib/src/sync/collection/upload.rs index 16c509a95..15dcafa02 100644 --- a/rslib/src/sync/collection/upload.rs +++ b/rslib/src/sync/collection/upload.rs @@ -18,8 +18,6 @@ use crate::collection::CollectionBuilder; use crate::error::SyncErrorKind; use crate::prelude::*; use crate::storage::SchemaVersion; -use crate::sync::collection::progress::FullSyncProgressFn; -use crate::sync::collection::protocol::SyncProtocol; use crate::sync::error::HttpResult; use crate::sync::error::OrHttpErr; use crate::sync::http_client::HttpSyncClient; @@ -34,15 +32,16 @@ pub const CORRUPT_MESSAGE: &str = impl Collection { /// Upload collection to AnkiWeb. Caller must re-open afterwards. - pub async fn full_upload(self, auth: SyncAuth, progress_fn: FullSyncProgressFn) -> Result<()> { - let mut server = HttpSyncClient::new(auth); - server.set_full_sync_progress_fn(Some(progress_fn)); - self.full_upload_with_server(server).await + pub async fn full_upload(self, auth: SyncAuth) -> Result<()> { + self.full_upload_with_server(HttpSyncClient::new(auth)) + .await } - pub(crate) async fn full_upload_with_server(mut self, server: HttpSyncClient) -> Result<()> { + // pub for tests + pub(super) async fn full_upload_with_server(mut self, server: HttpSyncClient) -> Result<()> { self.before_upload()?; let col_path = self.col_path.clone(); + let progress = self.new_progress_handler(); self.close(Some(SchemaVersion::V18))?; let col_data = fs::read(&col_path)?; @@ -55,7 +54,7 @@ impl Collection { } match server - .upload(col_data.try_into_sync_request()?) + .upload_with_progress(col_data.try_into_sync_request()?, progress) .await? .upload_response() { diff --git a/rslib/src/sync/http_client/full_sync.rs b/rslib/src/sync/http_client/full_sync.rs index 11c61661b..d1e7d56e9 100644 --- a/rslib/src/sync/http_client/full_sync.rs +++ b/rslib/src/sync/http_client/full_sync.rs @@ -7,8 +7,8 @@ use std::time::Duration; use tokio::select; use tokio::time::interval; +use crate::progress::ThrottlingProgressHandler; use crate::sync::collection::progress::FullSyncProgress; -use crate::sync::collection::progress::FullSyncProgressFn; use crate::sync::collection::protocol::EmptyInput; use crate::sync::collection::protocol::SyncMethod; use crate::sync::collection::upload::UploadResponse; @@ -19,49 +19,47 @@ use crate::sync::request::SyncRequest; use crate::sync::response::SyncResponse; impl HttpSyncClient { - pub fn set_full_sync_progress_fn(&mut self, func: Option) { - *self.full_sync_progress_fn.lock().unwrap() = func; - } - - fn full_sync_progress_monitor(&self, sending: bool) -> (IoMonitor, impl Future) { - let mut progress = FullSyncProgress { - transferred_bytes: 0, - total_bytes: 0, - }; - let mut progress_fn = self - .full_sync_progress_fn - .lock() - .unwrap() - .take() - .expect("progress func was not set"); + fn full_sync_progress_monitor( + &self, + sending: bool, + mut progress: ThrottlingProgressHandler, + ) -> (IoMonitor, impl Future) { let io_monitor = IoMonitor::new(); let io_monitor2 = io_monitor.clone(); let update_progress = async move { let mut interval = interval(Duration::from_millis(100)); loop { interval.tick().await; - let guard = io_monitor2.0.lock().unwrap(); - progress.total_bytes = if sending { - guard.total_bytes_to_send - } else { - guard.total_bytes_to_receive - } as usize; - progress.transferred_bytes = if sending { - guard.bytes_sent - } else { - guard.bytes_received - } as usize; - progress_fn(progress, true) + let (total_bytes, transferred_bytes) = { + let guard = io_monitor2.0.lock().unwrap(); + ( + if sending { + guard.total_bytes_to_send + } else { + guard.total_bytes_to_receive + }, + if sending { + guard.bytes_sent + } else { + guard.bytes_received + }, + ) + }; + _ = progress.update(false, |p| { + p.total_bytes = total_bytes as usize; + p.transferred_bytes = transferred_bytes as usize; + }) } }; (io_monitor, update_progress) } - pub(super) async fn download_inner( + pub(in super::super) async fn download_with_progress( &self, req: SyncRequest, + progress: ThrottlingProgressHandler, ) -> HttpResult>> { - let (io_monitor, progress_fut) = self.full_sync_progress_monitor(false); + let (io_monitor, progress_fut) = self.full_sync_progress_monitor(false, progress); let output = self.request_ext(SyncMethod::Download, req, io_monitor); select! { _ = progress_fut => unreachable!(), @@ -69,11 +67,12 @@ impl HttpSyncClient { } } - pub(super) async fn upload_inner( + pub(in super::super) async fn upload_with_progress( &self, req: SyncRequest>, + progress: ThrottlingProgressHandler, ) -> HttpResult> { - let (io_monitor, progress_fut) = self.full_sync_progress_monitor(true); + let (io_monitor, progress_fut) = self.full_sync_progress_monitor(true, progress); let output = self.request_ext(SyncMethod::Upload, req, io_monitor); select! { _ = progress_fut => unreachable!(), diff --git a/rslib/src/sync/http_client/mod.rs b/rslib/src/sync/http_client/mod.rs index 6d9eb9e5b..0131fd158 100644 --- a/rslib/src/sync/http_client/mod.rs +++ b/rslib/src/sync/http_client/mod.rs @@ -5,7 +5,6 @@ pub(crate) mod full_sync; pub(crate) mod io_monitor; mod protocol; -use std::sync::Mutex; use std::time::Duration; use reqwest::Client; @@ -14,7 +13,6 @@ use reqwest::StatusCode; use reqwest::Url; use crate::notes; -use crate::sync::collection::progress::FullSyncProgressFn; use crate::sync::collection::protocol::AsSyncEndpoint; use crate::sync::error::HttpError; use crate::sync::error::HttpResult; @@ -25,6 +23,7 @@ use crate::sync::request::header_and_stream::SYNC_HEADER_NAME; use crate::sync::request::SyncRequest; use crate::sync::response::SyncResponse; +#[derive(Clone)] pub struct HttpSyncClient { /// Set to the empty string for initial login pub sync_key: String, @@ -32,7 +31,6 @@ pub struct HttpSyncClient { client: Client, pub endpoint: Url, pub io_timeout: Duration, - full_sync_progress_fn: Mutex>, } impl HttpSyncClient { @@ -46,19 +44,6 @@ impl HttpSyncClient { .endpoint .unwrap_or_else(|| Url::try_from("https://sync.ankiweb.net/").unwrap()), io_timeout, - full_sync_progress_fn: Mutex::new(None), - } - } - - #[cfg(test)] - pub fn partial_clone(&self) -> Self { - Self { - sync_key: self.sync_key.clone(), - session_key: self.session_key.clone(), - client: self.client.clone(), - endpoint: self.endpoint.clone(), - full_sync_progress_fn: Mutex::new(None), - io_timeout: self.io_timeout, } } diff --git a/rslib/src/sync/http_client/protocol.rs b/rslib/src/sync/http_client/protocol.rs index 0c84b1b84..47c675f66 100644 --- a/rslib/src/sync/http_client/protocol.rs +++ b/rslib/src/sync/http_client/protocol.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use crate::prelude::TimestampMillis; +use crate::progress::ThrottlingProgressHandler; use crate::sync::collection::changes::ApplyChangesRequest; use crate::sync::collection::changes::UnchunkedChanges; use crate::sync::collection::chunks::ApplyChunkRequest; @@ -97,11 +98,13 @@ impl SyncProtocol for HttpSyncClient { } async fn upload(&self, req: SyncRequest>) -> HttpResult> { - self.upload_inner(req).await + self.upload_with_progress(req, ThrottlingProgressHandler::default()) + .await } async fn download(&self, req: SyncRequest) -> HttpResult>> { - self.download_inner(req).await + self.download_with_progress(req, ThrottlingProgressHandler::default()) + .await } } diff --git a/rslib/src/sync/media/progress.rs b/rslib/src/sync/media/progress.rs index ea295c0eb..c1ff77c85 100644 --- a/rslib/src/sync/media/progress.rs +++ b/rslib/src/sync/media/progress.rs @@ -9,3 +9,9 @@ pub struct MediaSyncProgress { pub uploaded_files: usize, pub uploaded_deletions: usize, } + +#[derive(Debug, Default, Clone, Copy)] +#[repr(transparent)] +pub struct MediaCheckProgress { + pub checked: usize, +} diff --git a/rslib/src/sync/media/syncer.rs b/rslib/src/sync/media/syncer.rs index 24d849132..b37c9e7b5 100644 --- a/rslib/src/sync/media/syncer.rs +++ b/rslib/src/sync/media/syncer.rs @@ -10,6 +10,7 @@ use crate::error::SyncErrorKind; use crate::media::files::mtime_as_i64; use crate::media::MediaManager; use crate::prelude::*; +use crate::progress::ThrottlingProgressHandler; use crate::sync::http_client::HttpSyncClient; use crate::sync::media::begin::SyncBeginRequest; use crate::sync::media::begin::SyncBeginResponse; @@ -30,41 +31,25 @@ use crate::sync::media::MAX_MEDIA_FILES_IN_ZIP; use crate::sync::request::IntoSyncRequest; use crate::version; -pub struct MediaSyncer

-where - P: FnMut(MediaSyncProgress) -> bool, -{ +pub struct MediaSyncer { mgr: MediaManager, client: HttpSyncClient, - progress_cb: P, - progress: MediaSyncProgress, + progress: ThrottlingProgressHandler, } -impl

MediaSyncer

-where - P: FnMut(MediaSyncProgress) -> bool, -{ +impl MediaSyncer { pub fn new( mgr: MediaManager, - progress_cb: P, + progress: ThrottlingProgressHandler, client: HttpSyncClient, - ) -> Result> { + ) -> Result { Ok(MediaSyncer { mgr, client, - progress_cb, - progress: Default::default(), + progress, }) } - fn fire_progress_cb(&mut self) -> Result<()> { - if (self.progress_cb)(self.progress) { - Ok(()) - } else { - Err(AnkiError::Interrupted) - } - } - pub async fn sync(&mut self) -> Result<()> { self.sync_inner().await.map_err(|e| { debug!("sync error: {:?}", e); @@ -100,8 +85,6 @@ where self.finalize_sync().await?; } - self.fire_progress_cb()?; - debug!("media sync complete"); Ok(()) @@ -129,16 +112,9 @@ where /// Make sure media DB is up to date. fn register_changes(&mut self) -> Result<()> { - // make borrow checker happy - let progress = &mut self.progress; - let progress_cb = &mut self.progress_cb; - - let progress = |checked| { - progress.checked = checked; - (progress_cb)(*progress) - }; - - ChangeTracker::new(self.mgr.media_folder.as_path(), progress).register_changes(&self.mgr.db) + let progress_cb = |checked| self.progress.update(true, |p| p.checked = checked).is_ok(); + ChangeTracker::new(self.mgr.media_folder.as_path(), progress_cb) + .register_changes(&self.mgr.db) } async fn fetch_changes(&mut self, mut meta: MediaDatabaseMetadata) -> Result<()> { @@ -157,16 +133,15 @@ where } last_usn = batch.last().unwrap().usn; - self.progress.checked += batch.len(); - self.fire_progress_cb()?; + self.progress.update(false, |p| p.checked += batch.len())?; let (to_download, to_delete, to_remove_pending) = changes::determine_required_changes(&self.mgr.db, batch)?; // file removal self.mgr.remove_files(to_delete.as_slice())?; - self.progress.downloaded_deletions += to_delete.len(); - self.fire_progress_cb()?; + self.progress + .update(false, |p| p.downloaded_deletions += to_delete.len())?; // file download let mut downloaded = vec![]; @@ -189,8 +164,7 @@ where dl_fnames = &dl_fnames[len..]; downloaded.extend(download_batch); - self.progress.downloaded_files += len; - self.fire_progress_cb()?; + self.progress.update(false, |p| p.downloaded_files += len)?; } // then update the DB @@ -227,8 +201,8 @@ where None => { // discard zip info and retry batch - not particularly efficient, // but this is a corner case - self.progress.checked += pending.len(); - self.fire_progress_cb()?; + self.progress + .update(false, |p| p.checked += pending.len())?; continue; } Some(data) => zip_files_for_upload(data)?, @@ -245,9 +219,10 @@ where .take(reply.processed) .partition(|e| e.sha1.is_some()); - self.progress.uploaded_files += processed_files.len(); - self.progress.uploaded_deletions += processed_deletions.len(); - self.fire_progress_cb()?; + self.progress.update(false, |p| { + p.uploaded_files += processed_files.len(); + p.uploaded_deletions += processed_deletions.len(); + })?; let fnames: Vec<_> = processed_files .into_iter() diff --git a/rslib/src/sync/media/tests.rs b/rslib/src/sync/media/tests.rs index 06a9bff65..c5a9f8acd 100644 --- a/rslib/src/sync/media/tests.rs +++ b/rslib/src/sync/media/tests.rs @@ -15,6 +15,7 @@ use reqwest::Client; use crate::error::Result; use crate::media::MediaManager; use crate::prelude::AnkiError; +use crate::progress::ThrottlingProgressHandler; use crate::sync::collection::protocol::AsSyncEndpoint; use crate::sync::collection::tests::with_active_server; use crate::sync::collection::tests::SyncTestContext; @@ -104,7 +105,7 @@ async fn legacy_session_key_works() -> Result<()> { #[tokio::test] async fn sanity_check() -> Result<()> { with_active_server(|client| async move { - let ctx = SyncTestContext::new(client.partial_clone()); + let ctx = SyncTestContext::new(client.clone()); let media1 = ctx.media1(); ctx.sync_media1().await?; // may be non-zero when testing on external endpoint @@ -134,8 +135,8 @@ async fn sanity_check() -> Result<()> { .await } -fn ignore_progress(_progress: MediaSyncProgress) -> bool { - true +fn ignore_progress() -> ThrottlingProgressHandler { + ThrottlingProgressHandler::new(Default::default()) } impl SyncTestContext { @@ -149,13 +150,13 @@ impl SyncTestContext { async fn sync_media1(&self) -> Result<()> { let mut syncer = - MediaSyncer::new(self.media1(), ignore_progress, self.client.partial_clone()).unwrap(); + MediaSyncer::new(self.media1(), ignore_progress(), self.client.clone()).unwrap(); syncer.sync().await } async fn sync_media2(&self) -> Result<()> { let mut syncer = - MediaSyncer::new(self.media2(), ignore_progress, self.client.partial_clone()).unwrap(); + MediaSyncer::new(self.media2(), ignore_progress(), self.client.clone()).unwrap(); syncer.sync().await } @@ -171,7 +172,7 @@ impl SyncTestContext { #[tokio::test] async fn media_roundtrip() -> Result<()> { with_active_server(|client| async move { - let ctx = SyncTestContext::new(client.partial_clone()); + let ctx = SyncTestContext::new(client.clone()); let media1 = ctx.media1(); let media2 = ctx.media2(); ctx.sync_media1().await?; @@ -219,7 +220,7 @@ async fn media_roundtrip() -> Result<()> { #[tokio::test] async fn parallel_requests() -> Result<()> { with_active_server(|client| async move { - let ctx = SyncTestContext::new(client.partial_clone()); + let ctx = SyncTestContext::new(client.clone()); let media1 = ctx.media1(); let media2 = ctx.media2(); ctx.sleep();