Implement HttpError directly so that sources can be extracted properly

When disabling the default snafu source handling, <dyn Error>.source()
doesn't work.
This commit is contained in:
Damien Elmes 2023-02-18 23:42:20 +10:00
parent d5772ac43a
commit 7ebf8dd84a
5 changed files with 50 additions and 58 deletions

View File

@ -14,7 +14,6 @@ use crate::sync::collection::normal::SyncActionRequired;
use crate::sync::collection::protocol::SyncProtocol;
use crate::sync::error::HttpError;
use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr;
use crate::sync::http_client::HttpSyncClient;
use crate::sync::request::IntoSyncRequest;
@ -139,13 +138,12 @@ impl Collection {
pub fn server_meta(req: MetaRequest, col: &mut Collection) -> HttpResult<SyncMeta> {
if !matches!(req.sync_version, SYNC_VERSION_MIN..=SYNC_VERSION_MAX) {
return HttpSnafu {
return Err(HttpError {
// old clients expected this code
code: StatusCode::NOT_IMPLEMENTED,
context: "unsupported version",
context: "unsupported version".into(),
source: None,
}
.fail();
});
}
let mut meta = col.sync_meta().or_internal_err("sync meta")?;
if meta.v2_scheduler_or_later && req.sync_version < SYNC_VERSION_09_V2_SCHEDULER {

View File

@ -1,28 +1,37 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::error::Error;
use std::fmt::Display;
use std::fmt::Formatter;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::response::Redirect;
use axum::response::Response;
use snafu::OptionExt;
use snafu::Snafu;
pub type HttpResult<T, E = HttpError> = std::result::Result<T, E>;
pub type HttpResult<T, E = HttpError> = Result<T, E>;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
#[derive(Debug)]
pub struct HttpError {
pub code: StatusCode,
pub context: String,
// snafu's automatic error conversion only supports Option if
// the whatever trait is derived, and deriving whatever means we
// can't have extra fields like `code`. Even without Option, the
// error conversion requires us to manually box the error, so we end
// up having to disable the default behaviour and add the error to the
// snafu ourselves
#[snafu(source(false))]
pub source: Option<Box<dyn std::error::Error + Send + Sync>>,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
impl Display for HttpError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{} (code={})", self.context, self.code.as_u16())
}
}
impl Error for HttpError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match &self.source {
None => None,
Some(err) => Some(err.as_ref()),
}
}
}
impl HttpError {
@ -114,7 +123,7 @@ pub trait OrHttpErr {
impl<T, E> OrHttpErr for Result<T, E>
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
type Value = T;
@ -123,13 +132,10 @@ where
code: StatusCode,
context: impl Into<String>,
) -> Result<Self::Value, HttpError> {
self.map_err(|err| {
HttpSnafu {
code,
context: context.into(),
source: err.into(),
}
.build()
self.map_err(|err| HttpError {
code,
context: context.into(),
source: Some(err.into()),
})
}
}
@ -142,9 +148,9 @@ impl<T> OrHttpErr for Option<T> {
code: StatusCode,
context: impl Into<String>,
) -> Result<Self::Value, HttpError> {
self.context(HttpSnafu {
self.ok_or_else(|| HttpError {
code,
context,
context: context.into(),
source: None,
})
}

View File

@ -27,7 +27,6 @@ use tokio_util::io::StreamReader;
use crate::error::Result;
use crate::sync::error::HttpError;
use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr;
use crate::sync::request::header_and_stream::decode_zstd_body_stream_for_client;
use crate::sync::request::header_and_stream::encode_zstd_body_stream;
@ -149,11 +148,11 @@ impl IoMonitor {
data = response_body_stream => Ok(data?),
// timeout
_ = self.timeout(stall_duration) => {
HttpSnafu {
Err(HttpError {
code: StatusCode::REQUEST_TIMEOUT,
context: "timeout monitor",
context: "timeout monitor".into(),
source: None,
}.fail()
})
}
}
}

View File

@ -18,7 +18,6 @@ use crate::sync::collection::progress::FullSyncProgressFn;
use crate::sync::collection::protocol::AsSyncEndpoint;
use crate::sync::error::HttpError;
use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::http_client::io_monitor::IoMonitor;
use crate::sync::login::SyncAuth;
use crate::sync::request::header_and_stream::SyncHeader;
@ -113,13 +112,12 @@ impl HttpSyncClient {
impl From<Error> for HttpError {
fn from(err: Error) -> Self {
HttpSnafu {
HttpError {
// we should perhaps make this Optional instead
code: err.status().unwrap_or(StatusCode::SEE_OTHER),
context: "from reqwest",
context: "from reqwest".into(),
source: Some(Box::new(err) as _),
}
.build()
}
}

View File

@ -21,8 +21,8 @@ use serde_derive::Serialize;
use tokio::io::AsyncReadExt;
use tokio_util::io::ReaderStream;
use crate::sync::error::HttpError;
use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr;
use crate::sync::request::SyncRequest;
use crate::sync::request::MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED;
@ -81,25 +81,19 @@ where
data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))),
);
let reader = async_compression::tokio::bufread::ZstdDecoder::new(reader);
ReaderStream::new(reader).map_err(|err| {
HttpSnafu {
code: StatusCode::BAD_REQUEST,
context: "decode zstd body",
source: Some(Box::new(err) as _),
}
.build()
ReaderStream::new(reader).map_err(|err| HttpError {
code: StatusCode::BAD_REQUEST,
context: "decode zstd body".into(),
source: Some(Box::new(err) as _),
})
}
pub fn encode_zstd_body(data: Vec<u8>) -> impl Stream<Item = HttpResult<Bytes>> + Unpin {
let enc = async_compression::tokio::bufread::ZstdEncoder::new(Cursor::new(data));
ReaderStream::new(enc).map_err(|err| {
HttpSnafu {
code: StatusCode::INTERNAL_SERVER_ERROR,
context: "encode zstd body",
source: Some(Box::new(err) as _),
}
.build()
ReaderStream::new(enc).map_err(|err| HttpError {
code: StatusCode::INTERNAL_SERVER_ERROR,
context: "encode zstd body".into(),
source: Some(Box::new(err) as _),
})
}
@ -112,13 +106,10 @@ where
data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))),
);
let reader = async_compression::tokio::bufread::ZstdEncoder::new(reader);
ReaderStream::new(reader).map_err(|err| {
HttpSnafu {
code: StatusCode::BAD_REQUEST,
context: "encode zstd body",
source: Some(Box::new(err) as _),
}
.build()
ReaderStream::new(reader).map_err(|err| HttpError {
code: StatusCode::BAD_REQUEST,
context: "encode zstd body".into(),
source: Some(Box::new(err) as _),
})
}