Patch summary: adds full-collection download/upload to the Rust sync HTTP client.
Files touched: rslib/Cargo.toml, rslib/src/sync/http_client.rs, rslib/src/sync/mod.rs.

NOTE(review): the copy this was rebuilt from had been flattened onto a few lines
and had lost every generic argument list that looked like a markup tag. Line
breaks were rebuilt so that each hunk matches the counts in its hunk header.
Generic arguments were restored by inference from where-clauses, return values
and imports. Confirm them against the upstream commit before applying, in
particular Option<String>, the Vec<...> element types in mod.rs, and the
Result<...> return types in http_client.rs.

diff --git a/rslib/Cargo.toml b/rslib/Cargo.toml
index 94707c0d3..0bb56f85e 100644
--- a/rslib/Cargo.toml
+++ b/rslib/Cargo.toml
@@ -23,7 +23,7 @@ unicode-normalization = "0.1.12"
 tempfile = "3.1.0"
 serde = "1.0.104"
 serde_json = "1.0.45"
-tokio = "0.2.11"
+tokio = { version = "0.2.11", features = ["fs"] }
 serde_derive = "1.0.104"
 zip = "0.5.4"
 serde_tuple = "0.4.0"
@@ -47,6 +47,8 @@ rand = "0.7.3"
 num-integer = "0.1.42"
 itertools = "0.9.0"
 flate2 = "1.0.14"
+pin-project = "0.4.17"
+async-compression = { version = "0.3.4", features = ["stream", "gzip"] }
 
 [target.'cfg(target_vendor="apple")'.dependencies.rusqlite]
 version = "0.23.1"
@@ -57,12 +59,15 @@ version = "0.23.1"
 features = ["trace", "functions", "collation", "bundled"]
 
 [target.'cfg(linux)'.dependencies]
-reqwest = { version = "0.10.1", features = ["json", "socks", "native-tls-vendored"] }
+reqwest = { version = "0.10.1", features = ["json", "socks", "stream", "native-tls-vendored"] }
 
 [target.'cfg(not(linux))'.dependencies]
-reqwest = { version = "0.10.1", features = ["json", "socks"] }
+reqwest = { version = "0.10.1", features = ["json", "socks", "stream" ] }
 
 [build-dependencies]
 prost-build = "0.6.1"
 fluent-syntax = "0.9.2"
 
+[dev-dependencies]
+env_logger = "0.7.1"
+
diff --git a/rslib/src/sync/http_client.rs b/rslib/src/sync/http_client.rs
index 6071eb7fb..1e855db8f 100644
--- a/rslib/src/sync/http_client.rs
+++ b/rslib/src/sync/http_client.rs
@@ -2,13 +2,16 @@
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
 
 use super::*;
+use bytes::Bytes;
+use futures::Stream;
+use reqwest::Body;
 
 static SYNC_VERSION: u8 = 10;
-pub struct HTTPSyncClient<'a> {
+pub struct HTTPSyncClient {
     hkey: Option<String>,
     skey: String,
     client: Client,
-    endpoint: &'a str,
+    endpoint: String,
 }
 
 #[derive(Serialize)]
@@ -68,14 +71,15 @@ struct SanityCheckIn {
 #[derive(Serialize)]
 struct Empty {}
 
-impl HTTPSyncClient<'_> {
-    pub fn new<'a>(endpoint: &'a str) -> HTTPSyncClient<'a> {
+impl HTTPSyncClient {
+    pub fn new<'a>(endpoint_suffix: &str) -> HTTPSyncClient {
         let client = Client::builder()
             .connect_timeout(Duration::from_secs(30))
             .timeout(Duration::from_secs(60))
             .build()
             .unwrap();
         let skey = guid();
+        let endpoint = endpoint(&endpoint_suffix);
         HTTPSyncClient {
             hkey: None,
             skey,
@@ -84,7 +88,7 @@ impl HTTPSyncClient<'_> {
         }
     }
 
-    async fn json_request<T>(&self, method: &str, json: &T) -> Result<Response>
+    async fn json_request<T>(&self, method: &str, json: &T, timeout_long: bool) -> Result<Response>
     where
         T: serde::Serialize,
     {
@@ -94,7 +98,7 @@ impl HTTPSyncClient<'_> {
         gz.write_all(&req_json)?;
         let part = multipart::Part::bytes(gz.finish()?);
 
-        self.request(method, part).await
+        self.request(method, part, timeout_long).await
     }
 
     async fn json_request_deserialized<T, T2>(&self, method: &str, json: &T) -> Result<T2>
@@ -102,14 +106,19 @@ impl HTTPSyncClient<'_> {
         T: Serialize,
         T2: DeserializeOwned,
     {
-        self.json_request(method, json)
+        self.json_request(method, json, false)
             .await?
             .json()
             .await
             .map_err(Into::into)
     }
 
-    async fn request(&self, method: &str, data_part: multipart::Part) -> Result<Response> {
+    async fn request(
+        &self,
+        method: &str,
+        data_part: multipart::Part,
+        timeout_long: bool,
+    ) -> Result<Response> {
         let data_part = data_part.file_name("data");
 
         let mut form = multipart::Form::new()
@@ -120,12 +129,16 @@ impl HTTPSyncClient<'_> {
         }
 
         let url = format!("{}{}", self.endpoint, method);
-        let req = self.client.post(&url).multipart(form);
+        let mut req = self.client.post(&url).multipart(form);
+
+        if timeout_long {
+            req = req.timeout(Duration::from_secs(60 * 60));
+        }
 
         req.send().await?.error_for_status().map_err(Into::into)
     }
 
-    async fn login(&mut self, username: &str, password: &str) -> Result<()> {
+    pub(crate) async fn login(&mut self, username: &str, password: &str) -> Result<()> {
         let resp: HostKeyOut = self
             .json_request_deserialized("hostKey", &HostKeyIn { username, password })
             .await?;
@@ -138,7 +151,7 @@ impl HTTPSyncClient<'_> {
         self.hkey.as_ref().unwrap()
     }
 
-    async fn meta(&mut self) -> Result<ServerMeta> {
+    pub(crate) async fn meta(&self) -> Result<ServerMeta> {
         let meta_in = MetaIn {
             sync_version: SYNC_VERSION,
             client_version: sync_client_version(),
@@ -146,49 +159,168 @@ impl HTTPSyncClient<'_> {
         self.json_request_deserialized("meta", &meta_in).await
     }
 
-    async fn start(&mut self, input: &StartIn) -> Result<Graves> {
-        self.json_request_deserialized("start", input).await
+    pub(crate) async fn start(
+        &self,
+        minimum_usn: Usn,
+        minutes_west: i32,
+        client_is_newer: bool,
+    ) -> Result<Graves> {
+        let input = StartIn {
+            minimum_usn,
+            minutes_west,
+            client_is_newer,
+            client_graves: None,
+        };
+        self.json_request_deserialized("start", &input).await
     }
 
-    async fn apply_graves(&mut self, chunk: Graves) -> Result<()> {
+    pub(crate) async fn apply_graves(&self, chunk: Graves) -> Result<()> {
         let input = ApplyGravesIn { chunk };
-        let resp = self.json_request("applyGraves", &input).await?;
+        let resp = self.json_request("applyGraves", &input, false).await?;
         resp.error_for_status()?;
         Ok(())
     }
 
-    async fn apply_changes(&mut self, changes: Changes) -> Result<Changes> {
+    pub(crate) async fn apply_changes(&self, changes: Changes) -> Result<Changes> {
         let input = ApplyChangesIn { changes };
         self.json_request_deserialized("applyChanges", &input).await
     }
 
-    async fn chunk(&mut self) -> Result<Chunk> {
+    pub(crate) async fn chunk(&self) -> Result<Chunk> {
         self.json_request_deserialized("chunk", &Empty {}).await
     }
 
-    async fn apply_chunk(&mut self, chunk: Chunk) -> Result<()> {
+    pub(crate) async fn apply_chunk(&self, chunk: Chunk) -> Result<()> {
         let input = ApplyChunkIn { chunk };
-        let resp = self.json_request("applyChunk", &input).await?;
+        let resp = self.json_request("applyChunk", &input, false).await?;
         resp.error_for_status()?;
         Ok(())
     }
 
-    async fn sanity_check(&mut self, client: SanityCheckCounts) -> Result<SanityCheckOut> {
+    pub(crate) async fn sanity_check(&self, client: SanityCheckCounts) -> Result<SanityCheckOut> {
         let input = SanityCheckIn { client };
         self.json_request_deserialized("sanityCheck2", &input).await
     }
 
-    async fn finish(&mut self) -> Result<()> {
-        let resp = self.json_request("finish", &Empty {}).await?;
+    pub(crate) async fn finish(&self) -> Result<()> {
+        let resp = self.json_request("finish", &Empty {}, false).await?;
         resp.error_for_status()?;
         Ok(())
     }
 
-    async fn abort(&mut self) -> Result<()> {
-        let resp = self.json_request("abort", &Empty {}).await?;
+    pub(crate) async fn abort(&self) -> Result<()> {
+        let resp = self.json_request("abort", &Empty {}, false).await?;
         resp.error_for_status()?;
         Ok(())
     }
+
+    async fn download_inner(
+        &self,
+    ) -> Result<(
+        usize,
+        impl Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
+    )> {
+        let resp: reqwest::Response = self.json_request("download", &Empty {}, true).await?;
+        let len = resp.content_length().unwrap_or_default();
+        Ok((len as usize, resp.bytes_stream()))
+    }
+
+    /// Download collection into a temporary file, returning it.
+    /// Caller should persist the file in the correct path after checking it.
+    pub(crate) async fn download<P>(&self, folder: &Path, progress_fn: P) -> Result<NamedTempFile>
+    where
+        P: Fn(&FullSyncProgress),
+    {
+        let mut temp_file = NamedTempFile::new_in(folder)?;
+        let (size, mut stream) = self.download_inner().await?;
+        let mut progress = FullSyncProgress {
+            transferred_bytes: 0,
+            total_bytes: size,
+        };
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            temp_file.write_all(&chunk)?;
+            progress.transferred_bytes += chunk.len();
+            progress_fn(&progress);
+        }
+
+        Ok(temp_file)
+    }
+
+    async fn upload_inner(&self, body: Body) -> Result<()> {
+        let data_part = multipart::Part::stream(body);
+        let resp = self.request("upload", data_part, true).await?;
+        resp.error_for_status()?;
+        Ok(())
+    }
+
+    pub(crate) async fn upload<P>(&mut self, col_path: &Path, progress_fn: P) -> Result<()>
+    where
+        P: Fn(&FullSyncProgress) + Send + Sync + 'static,
+    {
+        let file = tokio::fs::File::open(col_path).await?;
+        let total_bytes = file.metadata().await?.len() as usize;
+        let wrap1 = ProgressWrapper {
+            reader: file,
+            progress_fn,
+            progress: FullSyncProgress {
+                transferred_bytes: 0,
+                total_bytes,
+            },
+        };
+        let wrap2 = async_compression::stream::GzipEncoder::new(wrap1);
+        let body = Body::wrap_stream(wrap2);
+        self.upload_inner(body).await?;
+
+        Ok(())
+    }
+}
+
+use futures::{
+    ready,
+    task::{Context, Poll},
+};
+use pin_project::pin_project;
+use std::pin::Pin;
+use tokio::io::AsyncRead;
+
+#[pin_project]
+struct ProgressWrapper<S, P> {
+    #[pin]
+    reader: S,
+    progress_fn: P,
+    progress: FullSyncProgress,
+}
+
+impl<S, P> Stream for ProgressWrapper<S, P>
+where
+    S: AsyncRead,
+    P: Fn(&FullSyncProgress),
+{
+    type Item = std::result::Result<Bytes, std::io::Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut buf = vec![0; 16 * 1024];
+        let this = self.project();
+        match ready!(this.reader.poll_read(cx, &mut buf)) {
+            Ok(0) => Poll::Ready(None),
+            Ok(size) => {
+                buf.resize(size, 0);
+                this.progress.transferred_bytes += size;
+                (this.progress_fn)(&this.progress);
+                Poll::Ready(Some(Ok(Bytes::from(buf))))
+            }
+            Err(e) => Poll::Ready(Some(Err(e))),
+        }
+    }
+}
+
+fn endpoint(suffix: &str) -> String {
+    if let Ok(endpoint) = std::env::var("SYNC_ENDPOINT") {
+        endpoint
+    } else {
+        format!("https://sync{}.ankiweb.net/sync/", suffix)
+    }
 }
 
 #[cfg(test)]
@@ -197,10 +329,8 @@ mod test {
     use crate::err::SyncErrorKind;
     use tokio::runtime::Runtime;
 
-    static ENDPOINT: &'static str = "https://sync.ankiweb.net/sync/";
-
     async fn http_client_inner(username: String, password: String) -> Result<()> {
-        let mut syncer = HTTPSyncClient::new(ENDPOINT);
+        let mut syncer = HTTPSyncClient::new("");
 
         assert!(matches!(
             syncer.login("nosuchuser", "nosuchpass").await,
@@ -223,20 +353,13 @@ mod test {
             })
         ));
 
-        let input = StartIn {
-            minimum_usn: Usn(0),
-            minutes_west: 0,
-            client_is_newer: true,
-            client_graves: None,
-        };
-
-        let _graves = syncer.start(&input).await?;
+        let _graves = syncer.start(Usn(1), 0, true).await?;
 
         // aborting should now work
         syncer.abort().await?;
 
         // start again, and continue
-        let _graves = syncer.start(&input).await?;
+        let _graves = syncer.start(Usn(0), 0, true).await?;
 
         syncer.apply_graves(Graves::default()).await?;
 
@@ -268,6 +391,21 @@ mod test {
 
         syncer.finish().await?;
 
+        use tempfile::tempdir;
+
+        let dir = tempdir()?;
+        let out_path = syncer
+            .download(&dir.path(), |progress| {
+                println!("progress: {:?}", progress);
+            })
+            .await?;
+
+        syncer
+            .upload(&out_path.path(), |progress| {
+                println!("progress {:?}", progress);
+            })
+            .await?;
+
         Ok(())
     }
 
@@ -280,6 +418,7 @@ mod test {
             }
         };
         let pass = std::env::var("TEST_SYNC_PASS").unwrap();
+        env_logger::init();
 
         let mut rt = Runtime::new().unwrap();
         rt.block_on(http_client_inner(user, pass))
diff --git a/rslib/src/sync/mod.rs b/rslib/src/sync/mod.rs
index 907031e5f..bc843ff3a 100644
--- a/rslib/src/sync/mod.rs
+++ b/rslib/src/sync/mod.rs
@@ -14,18 +14,20 @@ use crate::{
 };
 use flate2::write::GzEncoder;
 use flate2::Compression;
+use futures::StreamExt;
 use reqwest::{multipart, Client, Response};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::Value;
 use serde_tuple::Serialize_tuple;
 use std::io::prelude::*;
-use std::{collections::HashMap, time::Duration};
+use std::{collections::HashMap, path::Path, time::Duration};
+use tempfile::NamedTempFile;
 
 #[derive(Default, Debug)]
 pub struct SyncProgress {}
 
 #[derive(Serialize, Deserialize, Debug)]
-struct ServerMeta {
+pub struct ServerMeta {
     #[serde(rename = "mod")]
     modified: TimestampMillis,
     #[serde(rename = "scm")]
@@ -42,20 +44,20 @@ struct ServerMeta {
 }
 
 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Graves {
+pub struct Graves {
     cards: Vec<CardID>,
     decks: Vec<DeckID>,
     notes: Vec<NoteID>,
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug, Default)]
-struct DecksAndConfig {
+pub struct DecksAndConfig {
     decks: Vec<Value>,
     config: Vec<Value>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Changes {
+pub struct Changes {
     #[serde(rename = "models")]
     notetypes: Vec<Value>,
     #[serde(rename = "decks")]
@@ -70,18 +72,18 @@ struct Changes {
 }
 
 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Chunk {
+pub struct Chunk {
     done: bool,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     revlog: Vec<ReviewLogEntry>,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     cards: Vec<CardEntry>,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     notes: Vec<NoteEntry>,
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct ReviewLogEntry {
+pub struct ReviewLogEntry {
     id: TimestampMillis,
     cid: CardID,
     usn: Usn,
@@ -97,7 +99,7 @@ struct ReviewLogEntry {
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct NoteEntry {
+pub struct NoteEntry {
     id: NoteID,
     guid: String,
     #[serde(rename = "mid")]
@@ -114,7 +116,7 @@ struct NoteEntry {
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct CardEntry {
+pub struct CardEntry {
     id: CardID,
     nid: NoteID,
     did: DeckID,
@@ -136,7 +138,7 @@ struct CardEntry {
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-struct SanityCheckOut {
+pub struct SanityCheckOut {
     status: SanityCheckStatus,
     #[serde(rename = "c")]
     client: Option<SanityCheckCounts>,
@@ -152,7 +154,7 @@ enum SanityCheckStatus {
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct SanityCheckCounts {
+pub struct SanityCheckCounts {
     counts: SanityCheckDueCounts,
     cards: u32,
     notes: u32,
@@ -165,8 +167,14 @@ struct SanityCheckCounts {
 }
 
 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct SanityCheckDueCounts {
+pub struct SanityCheckDueCounts {
     new: u32,
     learn: u32,
     review: u32,
 }
+
+#[derive(Debug, Default)]
+pub struct FullSyncProgress {
+    transferred_bytes: usize,
+    total_bytes: usize,
+}