full upload/download HTTP code

parent 4fcb10bfa9
commit 529e89f48e
@@ -23,7 +23,7 @@ unicode-normalization = "0.1.12"
 tempfile = "3.1.0"
 serde = "1.0.104"
 serde_json = "1.0.45"
-tokio = "0.2.11"
+tokio = { version = "0.2.11", features = ["fs"] }
 serde_derive = "1.0.104"
 zip = "0.5.4"
 serde_tuple = "0.4.0"
@@ -47,6 +47,8 @@ rand = "0.7.3"
 num-integer = "0.1.42"
 itertools = "0.9.0"
 flate2 = "1.0.14"
+pin-project = "0.4.17"
+async-compression = { version = "0.3.4", features = ["stream", "gzip"] }

 [target.'cfg(target_vendor="apple")'.dependencies.rusqlite]
 version = "0.23.1"
@@ -57,12 +59,15 @@ version = "0.23.1"
 features = ["trace", "functions", "collation", "bundled"]

 [target.'cfg(linux)'.dependencies]
-reqwest = { version = "0.10.1", features = ["json", "socks", "native-tls-vendored"] }
+reqwest = { version = "0.10.1", features = ["json", "socks", "stream", "native-tls-vendored"] }

 [target.'cfg(not(linux))'.dependencies]
-reqwest = { version = "0.10.1", features = ["json", "socks"] }
+reqwest = { version = "0.10.1", features = ["json", "socks", "stream" ] }

 [build-dependencies]
 prost-build = "0.6.1"
 fluent-syntax = "0.9.2"

+[dev-dependencies]
+env_logger = "0.7.1"
+
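The dependency changes above line up with the streaming transfer code added below: reqwest's "stream" feature exposes Response::bytes_stream() and Body::wrap_stream(), tokio's "fs" feature provides the async File used for uploads, and async-compression gzips the upload stream on the fly. A rough, hypothetical sketch of the two reqwest calls (the URL and error type are placeholders, not part of this commit):

```rust
use futures::StreamExt;

async fn streaming_round_trip(client: &reqwest::Client) -> Result<(), Box<dyn std::error::Error>> {
    // Download: consume the body chunk by chunk instead of buffering it all.
    let resp = client.get("https://example.invalid/collection").send().await?;
    let mut body = resp.bytes_stream();
    while let Some(chunk) = body.next().await {
        let chunk = chunk?;
        // write the chunk to disk, update a progress counter, etc.
        let _ = chunk.len();
    }

    // Upload: any Stream of Bytes can back a request body.
    let chunks = futures::stream::iter(vec![Ok::<_, std::io::Error>(bytes::Bytes::from_static(
        b"payload",
    ))]);
    client
        .post("https://example.invalid/upload")
        .body(reqwest::Body::wrap_stream(chunks))
        .send()
        .await?;
    Ok(())
}
```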
@@ -2,13 +2,16 @@
 // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html

 use super::*;
+use bytes::Bytes;
+use futures::Stream;
+use reqwest::Body;

 static SYNC_VERSION: u8 = 10;
-pub struct HTTPSyncClient<'a> {
+pub struct HTTPSyncClient {
     hkey: Option<String>,
     skey: String,
     client: Client,
-    endpoint: &'a str,
+    endpoint: String,
 }

 #[derive(Serialize)]
@@ -68,14 +71,15 @@ struct SanityCheckIn {
 #[derive(Serialize)]
 struct Empty {}

-impl HTTPSyncClient<'_> {
-    pub fn new<'a>(endpoint: &'a str) -> HTTPSyncClient<'a> {
+impl HTTPSyncClient {
+    pub fn new<'a>(endpoint_suffix: &str) -> HTTPSyncClient {
         let client = Client::builder()
             .connect_timeout(Duration::from_secs(30))
             .timeout(Duration::from_secs(60))
             .build()
             .unwrap();
         let skey = guid();
+        let endpoint = endpoint(&endpoint_suffix);
         HTTPSyncClient {
             hkey: None,
             skey,
@@ -84,7 +88,7 @@ impl HTTPSyncClient<'_> {
         }
     }

-    async fn json_request<T>(&self, method: &str, json: &T) -> Result<Response>
+    async fn json_request<T>(&self, method: &str, json: &T, timeout_long: bool) -> Result<Response>
     where
         T: serde::Serialize,
     {
@@ -94,7 +98,7 @@ impl HTTPSyncClient<'_> {
         gz.write_all(&req_json)?;
         let part = multipart::Part::bytes(gz.finish()?);

-        self.request(method, part).await
+        self.request(method, part, timeout_long).await
     }

     async fn json_request_deserialized<T, T2>(&self, method: &str, json: &T) -> Result<T2>
@@ -102,14 +106,19 @@ impl HTTPSyncClient<'_> {
         T: Serialize,
         T2: DeserializeOwned,
     {
-        self.json_request(method, json)
+        self.json_request(method, json, false)
             .await?
             .json()
             .await
             .map_err(Into::into)
     }

-    async fn request(&self, method: &str, data_part: multipart::Part) -> Result<Response> {
+    async fn request(
+        &self,
+        method: &str,
+        data_part: multipart::Part,
+        timeout_long: bool,
+    ) -> Result<Response> {
         let data_part = data_part.file_name("data");

         let mut form = multipart::Form::new()
@@ -120,12 +129,16 @@ impl HTTPSyncClient<'_> {
         }

         let url = format!("{}{}", self.endpoint, method);
-        let req = self.client.post(&url).multipart(form);
+        let mut req = self.client.post(&url).multipart(form);

+        if timeout_long {
+            req = req.timeout(Duration::from_secs(60 * 60));
+        }
+
         req.send().await?.error_for_status().map_err(Into::into)
     }

-    async fn login(&mut self, username: &str, password: &str) -> Result<()> {
+    pub(crate) async fn login(&mut self, username: &str, password: &str) -> Result<()> {
         let resp: HostKeyOut = self
             .json_request_deserialized("hostKey", &HostKeyIn { username, password })
             .await?;
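For reference, the timeout behaviour the change above relies on: the builder's 60-second timeout remains the default for every call, and only requests made with timeout_long = true override it per request. A minimal illustrative sketch (the endpoint URL is a placeholder):

```rust
use std::time::Duration;

fn build_request(client: &reqwest::Client, long: bool) -> reqwest::RequestBuilder {
    // The client-wide 60s timeout applies unless the request overrides it.
    let mut req = client.post("https://example.invalid/upload");
    if long {
        // Full uploads/downloads can take far longer than a JSON round trip.
        req = req.timeout(Duration::from_secs(60 * 60));
    }
    req
}
```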
@@ -138,7 +151,7 @@ impl HTTPSyncClient<'_> {
         self.hkey.as_ref().unwrap()
     }

-    async fn meta(&mut self) -> Result<ServerMeta> {
+    pub(crate) async fn meta(&self) -> Result<ServerMeta> {
         let meta_in = MetaIn {
             sync_version: SYNC_VERSION,
             client_version: sync_client_version(),
@@ -146,49 +159,168 @@ impl HTTPSyncClient<'_> {
         self.json_request_deserialized("meta", &meta_in).await
     }

-    async fn start(&mut self, input: &StartIn) -> Result<Graves> {
-        self.json_request_deserialized("start", input).await
+    pub(crate) async fn start(
+        &self,
+        minimum_usn: Usn,
+        minutes_west: i32,
+        client_is_newer: bool,
+    ) -> Result<Graves> {
+        let input = StartIn {
+            minimum_usn,
+            minutes_west,
+            client_is_newer,
+            client_graves: None,
+        };
+        self.json_request_deserialized("start", &input).await
     }

-    async fn apply_graves(&mut self, chunk: Graves) -> Result<()> {
+    pub(crate) async fn apply_graves(&self, chunk: Graves) -> Result<()> {
         let input = ApplyGravesIn { chunk };
-        let resp = self.json_request("applyGraves", &input).await?;
+        let resp = self.json_request("applyGraves", &input, false).await?;
         resp.error_for_status()?;
         Ok(())
     }

-    async fn apply_changes(&mut self, changes: Changes) -> Result<Changes> {
+    pub(crate) async fn apply_changes(&self, changes: Changes) -> Result<Changes> {
         let input = ApplyChangesIn { changes };
         self.json_request_deserialized("applyChanges", &input).await
     }

-    async fn chunk(&mut self) -> Result<Chunk> {
+    pub(crate) async fn chunk(&self) -> Result<Chunk> {
         self.json_request_deserialized("chunk", &Empty {}).await
     }

-    async fn apply_chunk(&mut self, chunk: Chunk) -> Result<()> {
+    pub(crate) async fn apply_chunk(&self, chunk: Chunk) -> Result<()> {
         let input = ApplyChunkIn { chunk };
-        let resp = self.json_request("applyChunk", &input).await?;
+        let resp = self.json_request("applyChunk", &input, false).await?;
         resp.error_for_status()?;
         Ok(())
     }

-    async fn sanity_check(&mut self, client: SanityCheckCounts) -> Result<SanityCheckOut> {
+    pub(crate) async fn sanity_check(&self, client: SanityCheckCounts) -> Result<SanityCheckOut> {
         let input = SanityCheckIn { client };
         self.json_request_deserialized("sanityCheck2", &input).await
     }

-    async fn finish(&mut self) -> Result<()> {
-        let resp = self.json_request("finish", &Empty {}).await?;
+    pub(crate) async fn finish(&self) -> Result<()> {
+        let resp = self.json_request("finish", &Empty {}, false).await?;
         resp.error_for_status()?;
         Ok(())
     }

-    async fn abort(&mut self) -> Result<()> {
-        let resp = self.json_request("abort", &Empty {}).await?;
+    pub(crate) async fn abort(&self) -> Result<()> {
+        let resp = self.json_request("abort", &Empty {}, false).await?;
         resp.error_for_status()?;
         Ok(())
     }
+
+    async fn download_inner(
+        &self,
+    ) -> Result<(
+        usize,
+        impl Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
+    )> {
+        let resp: reqwest::Response = self.json_request("download", &Empty {}, true).await?;
+        let len = resp.content_length().unwrap_or_default();
+        Ok((len as usize, resp.bytes_stream()))
+    }
+
+    /// Download collection into a temporary file, returning it.
+    /// Caller should persist the file in the correct path after checking it.
+    pub(crate) async fn download<P>(&self, folder: &Path, progress_fn: P) -> Result<NamedTempFile>
+    where
+        P: Fn(&FullSyncProgress),
+    {
+        let mut temp_file = NamedTempFile::new_in(folder)?;
+        let (size, mut stream) = self.download_inner().await?;
+        let mut progress = FullSyncProgress {
+            transferred_bytes: 0,
+            total_bytes: size,
+        };
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            temp_file.write_all(&chunk)?;
+            progress.transferred_bytes += chunk.len();
+            progress_fn(&progress);
+        }
+
+        Ok(temp_file)
+    }
+
+    async fn upload_inner(&self, body: Body) -> Result<()> {
+        let data_part = multipart::Part::stream(body);
+        let resp = self.request("upload", data_part, true).await?;
+        resp.error_for_status()?;
+        Ok(())
+    }
+
+    pub(crate) async fn upload<P>(&mut self, col_path: &Path, progress_fn: P) -> Result<()>
+    where
+        P: Fn(&FullSyncProgress) + Send + Sync + 'static,
+    {
+        let file = tokio::fs::File::open(col_path).await?;
+        let total_bytes = file.metadata().await?.len() as usize;
+        let wrap1 = ProgressWrapper {
+            reader: file,
+            progress_fn,
+            progress: FullSyncProgress {
+                transferred_bytes: 0,
+                total_bytes,
+            },
+        };
+        let wrap2 = async_compression::stream::GzipEncoder::new(wrap1);
+        let body = Body::wrap_stream(wrap2);
+        self.upload_inner(body).await?;
+
+        Ok(())
+    }
 }

+use futures::{
+    ready,
+    task::{Context, Poll},
+};
+use pin_project::pin_project;
+use std::pin::Pin;
+use tokio::io::AsyncRead;
+
+#[pin_project]
+struct ProgressWrapper<S, P> {
+    #[pin]
+    reader: S,
+    progress_fn: P,
+    progress: FullSyncProgress,
+}
+
+impl<S, P> Stream for ProgressWrapper<S, P>
+where
+    S: AsyncRead,
+    P: Fn(&FullSyncProgress),
+{
+    type Item = std::result::Result<Bytes, std::io::Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut buf = vec![0; 16 * 1024];
+        let this = self.project();
+        match ready!(this.reader.poll_read(cx, &mut buf)) {
+            Ok(0) => Poll::Ready(None),
+            Ok(size) => {
+                buf.resize(size, 0);
+                this.progress.transferred_bytes += size;
+                (this.progress_fn)(&this.progress);
+                Poll::Ready(Some(Ok(Bytes::from(buf))))
+            }
+            Err(e) => Poll::Ready(Some(Err(e))),
+        }
+    }
+}
+
+fn endpoint(suffix: &str) -> String {
+    if let Ok(endpoint) = std::env::var("SYNC_ENDPOINT") {
+        endpoint
+    } else {
+        format!("https://sync{}.ankiweb.net/sync/", suffix)
+    }
+}
+
 #[cfg(test)]
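The endpoint() helper added above means the sync target can be redirected without code changes. A hypothetical usage sketch (endpoint() is private to this module, so this would live alongside it, and the localhost URL is only an example):

```rust
// With SYNC_ENDPOINT set, the suffix is ignored and the variable wins.
std::env::set_var("SYNC_ENDPOINT", "http://localhost:8080/sync/");
assert_eq!(endpoint(""), "http://localhost:8080/sync/");

// Without it, the suffix selects an AnkiWeb shard.
std::env::remove_var("SYNC_ENDPOINT");
assert_eq!(endpoint("1"), "https://sync1.ankiweb.net/sync/");
```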
@@ -197,10 +329,8 @@ mod test {
     use crate::err::SyncErrorKind;
     use tokio::runtime::Runtime;

-    static ENDPOINT: &'static str = "https://sync.ankiweb.net/sync/";
-
     async fn http_client_inner(username: String, password: String) -> Result<()> {
-        let mut syncer = HTTPSyncClient::new(ENDPOINT);
+        let mut syncer = HTTPSyncClient::new("");

         assert!(matches!(
             syncer.login("nosuchuser", "nosuchpass").await,
@@ -223,20 +353,13 @@ mod test {
             })
         ));

-        let input = StartIn {
-            minimum_usn: Usn(0),
-            minutes_west: 0,
-            client_is_newer: true,
-            client_graves: None,
-        };
-
-        let _graves = syncer.start(&input).await?;
+        let _graves = syncer.start(Usn(1), 0, true).await?;

         // aborting should now work
         syncer.abort().await?;

         // start again, and continue
-        let _graves = syncer.start(&input).await?;
+        let _graves = syncer.start(Usn(0), 0, true).await?;

         syncer.apply_graves(Graves::default()).await?;

@@ -268,6 +391,21 @@ mod test {

         syncer.finish().await?;

+        use tempfile::tempdir;
+
+        let dir = tempdir()?;
+        let out_path = syncer
+            .download(&dir.path(), |progress| {
+                println!("progress: {:?}", progress);
+            })
+            .await?;
+
+        syncer
+            .upload(&out_path.path(), |progress| {
+                println!("progress {:?}", progress);
+            })
+            .await?;
+
         Ok(())
     }

@@ -280,6 +418,7 @@ mod test {
             }
         };
         let pass = std::env::var("TEST_SYNC_PASS").unwrap();
+        env_logger::init();

         let mut rt = Runtime::new().unwrap();
         rt.block_on(http_client_inner(user, pass))
@@ -14,18 +14,20 @@ use crate::{
 };
 use flate2::write::GzEncoder;
 use flate2::Compression;
+use futures::StreamExt;
 use reqwest::{multipart, Client, Response};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::Value;
 use serde_tuple::Serialize_tuple;
 use std::io::prelude::*;
-use std::{collections::HashMap, time::Duration};
+use std::{collections::HashMap, path::Path, time::Duration};
+use tempfile::NamedTempFile;

 #[derive(Default, Debug)]
 pub struct SyncProgress {}

 #[derive(Serialize, Deserialize, Debug)]
-struct ServerMeta {
+pub struct ServerMeta {
     #[serde(rename = "mod")]
     modified: TimestampMillis,
     #[serde(rename = "scm")]
@@ -42,20 +44,20 @@ struct ServerMeta {
 }

 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Graves {
+pub struct Graves {
     cards: Vec<CardID>,
     decks: Vec<DeckID>,
     notes: Vec<NoteID>,
 }

 #[derive(Serialize_tuple, Deserialize, Debug, Default)]
-struct DecksAndConfig {
+pub struct DecksAndConfig {
     decks: Vec<DeckSchema11>,
     config: Vec<DeckConfSchema11>,
 }

 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Changes {
+pub struct Changes {
     #[serde(rename = "models")]
     notetypes: Vec<NoteTypeSchema11>,
     #[serde(rename = "decks")]
@@ -70,18 +72,18 @@ struct Changes {
 }

 #[derive(Serialize, Deserialize, Debug, Default)]
-struct Chunk {
+pub struct Chunk {
     done: bool,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     revlog: Vec<ReviewLogEntry>,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     cards: Vec<CardEntry>,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
     notes: Vec<NoteEntry>,
 }

 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct ReviewLogEntry {
+pub struct ReviewLogEntry {
     id: TimestampMillis,
     cid: CardID,
     usn: Usn,
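The added `default` attributes matter because skip_serializing_if means a Chunk with empty vectors omits those keys entirely, so deserializing such JSON needs a fallback. A standalone illustration of the same serde pattern (MiniChunk is a made-up stand-in, not a type from this diff; requires serde's derive support):

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Default, Debug)]
struct MiniChunk {
    done: bool,
    // Without `default`, deserializing JSON that omits `cards` would fail.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    cards: Vec<u32>,
}

fn main() {
    let json = serde_json::to_string(&MiniChunk {
        done: true,
        cards: vec![],
    })
    .unwrap();
    // The empty vector is skipped on the way out...
    assert_eq!(json, r#"{"done":true}"#);

    // ...and filled back in from Vec::default() on the way in.
    let back: MiniChunk = serde_json::from_str(&json).unwrap();
    assert!(back.cards.is_empty());
}
```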
@@ -97,7 +99,7 @@ struct ReviewLogEntry {
 }

 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct NoteEntry {
+pub struct NoteEntry {
     id: NoteID,
     guid: String,
     #[serde(rename = "mid")]
@@ -114,7 +116,7 @@ struct NoteEntry {
 }

 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct CardEntry {
+pub struct CardEntry {
     id: CardID,
     nid: NoteID,
     did: DeckID,
@@ -136,7 +138,7 @@ struct CardEntry {
 }

 #[derive(Serialize, Deserialize, Debug)]
-struct SanityCheckOut {
+pub struct SanityCheckOut {
     status: SanityCheckStatus,
     #[serde(rename = "c")]
     client: Option<SanityCheckCounts>,
@@ -152,7 +154,7 @@ enum SanityCheckStatus {
 }

 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct SanityCheckCounts {
+pub struct SanityCheckCounts {
     counts: SanityCheckDueCounts,
     cards: u32,
     notes: u32,
@@ -165,8 +167,14 @@ struct SanityCheckCounts {
 }

 #[derive(Serialize_tuple, Deserialize, Debug)]
-struct SanityCheckDueCounts {
+pub struct SanityCheckDueCounts {
     new: u32,
     learn: u32,
     review: u32,
 }
+
+#[derive(Debug, Default)]
+pub struct FullSyncProgress {
+    transferred_bytes: usize,
+    total_bytes: usize,
+}
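FullSyncProgress is what the download()/upload() progress callbacks receive. A hypothetical callback that code inside the sync module might pass (the percentage arithmetic is illustrative, not from the commit; the fields are private, so this only works module-side):

```rust
// Both fields are byte counts, so a caller can derive a rough percentage.
let report = |p: &FullSyncProgress| {
    let pct = if p.total_bytes > 0 {
        p.transferred_bytes * 100 / p.total_bytes
    } else {
        0
    };
    println!("synced {} of {} bytes ({}%)", p.transferred_bytes, p.total_bytes, pct);
};
```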