full upload/download HTTP code

This commit is contained in:
Damien Elmes 2020-05-27 19:32:05 +10:00
parent 4fcb10bfa9
commit 529e89f48e
3 changed files with 206 additions and 54 deletions

View File

@ -23,7 +23,7 @@ unicode-normalization = "0.1.12"
tempfile = "3.1.0" tempfile = "3.1.0"
serde = "1.0.104" serde = "1.0.104"
serde_json = "1.0.45" serde_json = "1.0.45"
tokio = "0.2.11" tokio = { version = "0.2.11", features = ["fs"] }
serde_derive = "1.0.104" serde_derive = "1.0.104"
zip = "0.5.4" zip = "0.5.4"
serde_tuple = "0.4.0" serde_tuple = "0.4.0"
@ -47,6 +47,8 @@ rand = "0.7.3"
num-integer = "0.1.42" num-integer = "0.1.42"
itertools = "0.9.0" itertools = "0.9.0"
flate2 = "1.0.14" flate2 = "1.0.14"
pin-project = "0.4.17"
async-compression = { version = "0.3.4", features = ["stream", "gzip"] }
[target.'cfg(target_vendor="apple")'.dependencies.rusqlite] [target.'cfg(target_vendor="apple")'.dependencies.rusqlite]
version = "0.23.1" version = "0.23.1"
@ -57,12 +59,15 @@ version = "0.23.1"
features = ["trace", "functions", "collation", "bundled"] features = ["trace", "functions", "collation", "bundled"]
[target.'cfg(linux)'.dependencies] [target.'cfg(linux)'.dependencies]
reqwest = { version = "0.10.1", features = ["json", "socks", "native-tls-vendored"] } reqwest = { version = "0.10.1", features = ["json", "socks", "stream", "native-tls-vendored"] }
[target.'cfg(not(linux))'.dependencies] [target.'cfg(not(linux))'.dependencies]
reqwest = { version = "0.10.1", features = ["json", "socks"] } reqwest = { version = "0.10.1", features = ["json", "socks", "stream" ] }
[build-dependencies] [build-dependencies]
prost-build = "0.6.1" prost-build = "0.6.1"
fluent-syntax = "0.9.2" fluent-syntax = "0.9.2"
[dev-dependencies]
env_logger = "0.7.1"

View File

@ -2,13 +2,16 @@
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use super::*; use super::*;
use bytes::Bytes;
use futures::Stream;
use reqwest::Body;
static SYNC_VERSION: u8 = 10; static SYNC_VERSION: u8 = 10;
pub struct HTTPSyncClient<'a> { pub struct HTTPSyncClient {
hkey: Option<String>, hkey: Option<String>,
skey: String, skey: String,
client: Client, client: Client,
endpoint: &'a str, endpoint: String,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -68,14 +71,15 @@ struct SanityCheckIn {
#[derive(Serialize)] #[derive(Serialize)]
struct Empty {} struct Empty {}
impl HTTPSyncClient<'_> { impl HTTPSyncClient {
pub fn new<'a>(endpoint: &'a str) -> HTTPSyncClient<'a> { pub fn new<'a>(endpoint_suffix: &str) -> HTTPSyncClient {
let client = Client::builder() let client = Client::builder()
.connect_timeout(Duration::from_secs(30)) .connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(60)) .timeout(Duration::from_secs(60))
.build() .build()
.unwrap(); .unwrap();
let skey = guid(); let skey = guid();
let endpoint = endpoint(&endpoint_suffix);
HTTPSyncClient { HTTPSyncClient {
hkey: None, hkey: None,
skey, skey,
@ -84,7 +88,7 @@ impl HTTPSyncClient<'_> {
} }
} }
async fn json_request<T>(&self, method: &str, json: &T) -> Result<Response> async fn json_request<T>(&self, method: &str, json: &T, timeout_long: bool) -> Result<Response>
where where
T: serde::Serialize, T: serde::Serialize,
{ {
@ -94,7 +98,7 @@ impl HTTPSyncClient<'_> {
gz.write_all(&req_json)?; gz.write_all(&req_json)?;
let part = multipart::Part::bytes(gz.finish()?); let part = multipart::Part::bytes(gz.finish()?);
self.request(method, part).await self.request(method, part, timeout_long).await
} }
async fn json_request_deserialized<T, T2>(&self, method: &str, json: &T) -> Result<T2> async fn json_request_deserialized<T, T2>(&self, method: &str, json: &T) -> Result<T2>
@ -102,14 +106,19 @@ impl HTTPSyncClient<'_> {
T: Serialize, T: Serialize,
T2: DeserializeOwned, T2: DeserializeOwned,
{ {
self.json_request(method, json) self.json_request(method, json, false)
.await? .await?
.json() .json()
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
async fn request(&self, method: &str, data_part: multipart::Part) -> Result<Response> { async fn request(
&self,
method: &str,
data_part: multipart::Part,
timeout_long: bool,
) -> Result<Response> {
let data_part = data_part.file_name("data"); let data_part = data_part.file_name("data");
let mut form = multipart::Form::new() let mut form = multipart::Form::new()
@ -120,12 +129,16 @@ impl HTTPSyncClient<'_> {
} }
let url = format!("{}{}", self.endpoint, method); let url = format!("{}{}", self.endpoint, method);
let req = self.client.post(&url).multipart(form); let mut req = self.client.post(&url).multipart(form);
if timeout_long {
req = req.timeout(Duration::from_secs(60 * 60));
}
req.send().await?.error_for_status().map_err(Into::into) req.send().await?.error_for_status().map_err(Into::into)
} }
async fn login(&mut self, username: &str, password: &str) -> Result<()> { pub(crate) async fn login(&mut self, username: &str, password: &str) -> Result<()> {
let resp: HostKeyOut = self let resp: HostKeyOut = self
.json_request_deserialized("hostKey", &HostKeyIn { username, password }) .json_request_deserialized("hostKey", &HostKeyIn { username, password })
.await?; .await?;
@ -138,7 +151,7 @@ impl HTTPSyncClient<'_> {
self.hkey.as_ref().unwrap() self.hkey.as_ref().unwrap()
} }
async fn meta(&mut self) -> Result<ServerMeta> { pub(crate) async fn meta(&self) -> Result<ServerMeta> {
let meta_in = MetaIn { let meta_in = MetaIn {
sync_version: SYNC_VERSION, sync_version: SYNC_VERSION,
client_version: sync_client_version(), client_version: sync_client_version(),
@ -146,49 +159,168 @@ impl HTTPSyncClient<'_> {
self.json_request_deserialized("meta", &meta_in).await self.json_request_deserialized("meta", &meta_in).await
} }
async fn start(&mut self, input: &StartIn) -> Result<Graves> { pub(crate) async fn start(
self.json_request_deserialized("start", input).await &self,
minimum_usn: Usn,
minutes_west: i32,
client_is_newer: bool,
) -> Result<Graves> {
let input = StartIn {
minimum_usn,
minutes_west,
client_is_newer,
client_graves: None,
};
self.json_request_deserialized("start", &input).await
} }
async fn apply_graves(&mut self, chunk: Graves) -> Result<()> { pub(crate) async fn apply_graves(&self, chunk: Graves) -> Result<()> {
let input = ApplyGravesIn { chunk }; let input = ApplyGravesIn { chunk };
let resp = self.json_request("applyGraves", &input).await?; let resp = self.json_request("applyGraves", &input, false).await?;
resp.error_for_status()?; resp.error_for_status()?;
Ok(()) Ok(())
} }
async fn apply_changes(&mut self, changes: Changes) -> Result<Changes> { pub(crate) async fn apply_changes(&self, changes: Changes) -> Result<Changes> {
let input = ApplyChangesIn { changes }; let input = ApplyChangesIn { changes };
self.json_request_deserialized("applyChanges", &input).await self.json_request_deserialized("applyChanges", &input).await
} }
async fn chunk(&mut self) -> Result<Chunk> { pub(crate) async fn chunk(&self) -> Result<Chunk> {
self.json_request_deserialized("chunk", &Empty {}).await self.json_request_deserialized("chunk", &Empty {}).await
} }
async fn apply_chunk(&mut self, chunk: Chunk) -> Result<()> { pub(crate) async fn apply_chunk(&self, chunk: Chunk) -> Result<()> {
let input = ApplyChunkIn { chunk }; let input = ApplyChunkIn { chunk };
let resp = self.json_request("applyChunk", &input).await?; let resp = self.json_request("applyChunk", &input, false).await?;
resp.error_for_status()?; resp.error_for_status()?;
Ok(()) Ok(())
} }
async fn sanity_check(&mut self, client: SanityCheckCounts) -> Result<SanityCheckOut> { pub(crate) async fn sanity_check(&self, client: SanityCheckCounts) -> Result<SanityCheckOut> {
let input = SanityCheckIn { client }; let input = SanityCheckIn { client };
self.json_request_deserialized("sanityCheck2", &input).await self.json_request_deserialized("sanityCheck2", &input).await
} }
async fn finish(&mut self) -> Result<()> { pub(crate) async fn finish(&self) -> Result<()> {
let resp = self.json_request("finish", &Empty {}).await?; let resp = self.json_request("finish", &Empty {}, false).await?;
resp.error_for_status()?; resp.error_for_status()?;
Ok(()) Ok(())
} }
async fn abort(&mut self) -> Result<()> { pub(crate) async fn abort(&self) -> Result<()> {
let resp = self.json_request("abort", &Empty {}).await?; let resp = self.json_request("abort", &Empty {}, false).await?;
resp.error_for_status()?; resp.error_for_status()?;
Ok(()) Ok(())
} }
async fn download_inner(
&self,
) -> Result<(
usize,
impl Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
)> {
let resp: reqwest::Response = self.json_request("download", &Empty {}, true).await?;
let len = resp.content_length().unwrap_or_default();
Ok((len as usize, resp.bytes_stream()))
}
/// Download collection into a temporary file, returning it.
/// Caller should persist the file in the correct path after checking it.
pub(crate) async fn download<P>(&self, folder: &Path, progress_fn: P) -> Result<NamedTempFile>
where
P: Fn(&FullSyncProgress),
{
let mut temp_file = NamedTempFile::new_in(folder)?;
let (size, mut stream) = self.download_inner().await?;
let mut progress = FullSyncProgress {
transferred_bytes: 0,
total_bytes: size,
};
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
temp_file.write_all(&chunk)?;
progress.transferred_bytes += chunk.len();
progress_fn(&progress);
}
Ok(temp_file)
}
async fn upload_inner(&self, body: Body) -> Result<()> {
let data_part = multipart::Part::stream(body);
let resp = self.request("upload", data_part, true).await?;
resp.error_for_status()?;
Ok(())
}
pub(crate) async fn upload<P>(&mut self, col_path: &Path, progress_fn: P) -> Result<()>
where
P: Fn(&FullSyncProgress) + Send + Sync + 'static,
{
let file = tokio::fs::File::open(col_path).await?;
let total_bytes = file.metadata().await?.len() as usize;
let wrap1 = ProgressWrapper {
reader: file,
progress_fn,
progress: FullSyncProgress {
transferred_bytes: 0,
total_bytes,
},
};
let wrap2 = async_compression::stream::GzipEncoder::new(wrap1);
let body = Body::wrap_stream(wrap2);
self.upload_inner(body).await?;
Ok(())
}
}
use futures::{
ready,
task::{Context, Poll},
};
use pin_project::pin_project;
use std::pin::Pin;
use tokio::io::AsyncRead;
#[pin_project]
struct ProgressWrapper<S, P> {
#[pin]
reader: S,
progress_fn: P,
progress: FullSyncProgress,
}
impl<S, P> Stream for ProgressWrapper<S, P>
where
S: AsyncRead,
P: Fn(&FullSyncProgress),
{
type Item = std::result::Result<Bytes, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut buf = vec![0; 16 * 1024];
let this = self.project();
match ready!(this.reader.poll_read(cx, &mut buf)) {
Ok(0) => Poll::Ready(None),
Ok(size) => {
buf.resize(size, 0);
this.progress.transferred_bytes += size;
(this.progress_fn)(&this.progress);
Poll::Ready(Some(Ok(Bytes::from(buf))))
}
Err(e) => Poll::Ready(Some(Err(e))),
}
}
}
fn endpoint(suffix: &str) -> String {
if let Ok(endpoint) = std::env::var("SYNC_ENDPOINT") {
endpoint
} else {
format!("https://sync{}.ankiweb.net/sync/", suffix)
}
} }
#[cfg(test)] #[cfg(test)]
@ -197,10 +329,8 @@ mod test {
use crate::err::SyncErrorKind; use crate::err::SyncErrorKind;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
static ENDPOINT: &'static str = "https://sync.ankiweb.net/sync/";
async fn http_client_inner(username: String, password: String) -> Result<()> { async fn http_client_inner(username: String, password: String) -> Result<()> {
let mut syncer = HTTPSyncClient::new(ENDPOINT); let mut syncer = HTTPSyncClient::new("");
assert!(matches!( assert!(matches!(
syncer.login("nosuchuser", "nosuchpass").await, syncer.login("nosuchuser", "nosuchpass").await,
@ -223,20 +353,13 @@ mod test {
}) })
)); ));
let input = StartIn { let _graves = syncer.start(Usn(1), 0, true).await?;
minimum_usn: Usn(0),
minutes_west: 0,
client_is_newer: true,
client_graves: None,
};
let _graves = syncer.start(&input).await?;
// aborting should now work // aborting should now work
syncer.abort().await?; syncer.abort().await?;
// start again, and continue // start again, and continue
let _graves = syncer.start(&input).await?; let _graves = syncer.start(Usn(0), 0, true).await?;
syncer.apply_graves(Graves::default()).await?; syncer.apply_graves(Graves::default()).await?;
@ -268,6 +391,21 @@ mod test {
syncer.finish().await?; syncer.finish().await?;
use tempfile::tempdir;
let dir = tempdir()?;
let out_path = syncer
.download(&dir.path(), |progress| {
println!("progress: {:?}", progress);
})
.await?;
syncer
.upload(&out_path.path(), |progress| {
println!("progress {:?}", progress);
})
.await?;
Ok(()) Ok(())
} }
@ -280,6 +418,7 @@ mod test {
} }
}; };
let pass = std::env::var("TEST_SYNC_PASS").unwrap(); let pass = std::env::var("TEST_SYNC_PASS").unwrap();
env_logger::init();
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
rt.block_on(http_client_inner(user, pass)) rt.block_on(http_client_inner(user, pass))

View File

@ -14,18 +14,20 @@ use crate::{
}; };
use flate2::write::GzEncoder; use flate2::write::GzEncoder;
use flate2::Compression; use flate2::Compression;
use futures::StreamExt;
use reqwest::{multipart, Client, Response}; use reqwest::{multipart, Client, Response};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use serde_tuple::Serialize_tuple; use serde_tuple::Serialize_tuple;
use std::io::prelude::*; use std::io::prelude::*;
use std::{collections::HashMap, time::Duration}; use std::{collections::HashMap, path::Path, time::Duration};
use tempfile::NamedTempFile;
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct SyncProgress {} pub struct SyncProgress {}
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
struct ServerMeta { pub struct ServerMeta {
#[serde(rename = "mod")] #[serde(rename = "mod")]
modified: TimestampMillis, modified: TimestampMillis,
#[serde(rename = "scm")] #[serde(rename = "scm")]
@ -42,20 +44,20 @@ struct ServerMeta {
} }
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default)]
struct Graves { pub struct Graves {
cards: Vec<CardID>, cards: Vec<CardID>,
decks: Vec<DeckID>, decks: Vec<DeckID>,
notes: Vec<NoteID>, notes: Vec<NoteID>,
} }
#[derive(Serialize_tuple, Deserialize, Debug, Default)] #[derive(Serialize_tuple, Deserialize, Debug, Default)]
struct DecksAndConfig { pub struct DecksAndConfig {
decks: Vec<DeckSchema11>, decks: Vec<DeckSchema11>,
config: Vec<DeckConfSchema11>, config: Vec<DeckConfSchema11>,
} }
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default)]
struct Changes { pub struct Changes {
#[serde(rename = "models")] #[serde(rename = "models")]
notetypes: Vec<NoteTypeSchema11>, notetypes: Vec<NoteTypeSchema11>,
#[serde(rename = "decks")] #[serde(rename = "decks")]
@ -70,18 +72,18 @@ struct Changes {
} }
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default)]
struct Chunk { pub struct Chunk {
done: bool, done: bool,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty", default)]
revlog: Vec<ReviewLogEntry>, revlog: Vec<ReviewLogEntry>,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty", default)]
cards: Vec<CardEntry>, cards: Vec<CardEntry>,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty", default)]
notes: Vec<NoteEntry>, notes: Vec<NoteEntry>,
} }
#[derive(Serialize_tuple, Deserialize, Debug)] #[derive(Serialize_tuple, Deserialize, Debug)]
struct ReviewLogEntry { pub struct ReviewLogEntry {
id: TimestampMillis, id: TimestampMillis,
cid: CardID, cid: CardID,
usn: Usn, usn: Usn,
@ -97,7 +99,7 @@ struct ReviewLogEntry {
} }
#[derive(Serialize_tuple, Deserialize, Debug)] #[derive(Serialize_tuple, Deserialize, Debug)]
struct NoteEntry { pub struct NoteEntry {
id: NoteID, id: NoteID,
guid: String, guid: String,
#[serde(rename = "mid")] #[serde(rename = "mid")]
@ -114,7 +116,7 @@ struct NoteEntry {
} }
#[derive(Serialize_tuple, Deserialize, Debug)] #[derive(Serialize_tuple, Deserialize, Debug)]
struct CardEntry { pub struct CardEntry {
id: CardID, id: CardID,
nid: NoteID, nid: NoteID,
did: DeckID, did: DeckID,
@ -136,7 +138,7 @@ struct CardEntry {
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
struct SanityCheckOut { pub struct SanityCheckOut {
status: SanityCheckStatus, status: SanityCheckStatus,
#[serde(rename = "c")] #[serde(rename = "c")]
client: Option<SanityCheckCounts>, client: Option<SanityCheckCounts>,
@ -152,7 +154,7 @@ enum SanityCheckStatus {
} }
#[derive(Serialize_tuple, Deserialize, Debug)] #[derive(Serialize_tuple, Deserialize, Debug)]
struct SanityCheckCounts { pub struct SanityCheckCounts {
counts: SanityCheckDueCounts, counts: SanityCheckDueCounts,
cards: u32, cards: u32,
notes: u32, notes: u32,
@ -165,8 +167,14 @@ struct SanityCheckCounts {
} }
#[derive(Serialize_tuple, Deserialize, Debug)] #[derive(Serialize_tuple, Deserialize, Debug)]
struct SanityCheckDueCounts { pub struct SanityCheckDueCounts {
new: u32, new: u32,
learn: u32, learn: u32,
review: u32, review: u32,
} }
#[derive(Debug, Default)]
pub struct FullSyncProgress {
transferred_bytes: usize,
total_bytes: usize,
}