Implement HttpError directly so that sources can be extracted properly

When disabling the default snafu source handling, <dyn Error>.source()
doesn't work.
This commit is contained in:
Damien Elmes 2023-02-18 23:42:20 +10:00
parent d5772ac43a
commit 7ebf8dd84a
5 changed files with 50 additions and 58 deletions

View File

@ -14,7 +14,6 @@ use crate::sync::collection::normal::SyncActionRequired;
use crate::sync::collection::protocol::SyncProtocol; use crate::sync::collection::protocol::SyncProtocol;
use crate::sync::error::HttpError; use crate::sync::error::HttpError;
use crate::sync::error::HttpResult; use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr; use crate::sync::error::OrHttpErr;
use crate::sync::http_client::HttpSyncClient; use crate::sync::http_client::HttpSyncClient;
use crate::sync::request::IntoSyncRequest; use crate::sync::request::IntoSyncRequest;
@ -139,13 +138,12 @@ impl Collection {
pub fn server_meta(req: MetaRequest, col: &mut Collection) -> HttpResult<SyncMeta> { pub fn server_meta(req: MetaRequest, col: &mut Collection) -> HttpResult<SyncMeta> {
if !matches!(req.sync_version, SYNC_VERSION_MIN..=SYNC_VERSION_MAX) { if !matches!(req.sync_version, SYNC_VERSION_MIN..=SYNC_VERSION_MAX) {
return HttpSnafu { return Err(HttpError {
// old clients expected this code // old clients expected this code
code: StatusCode::NOT_IMPLEMENTED, code: StatusCode::NOT_IMPLEMENTED,
context: "unsupported version", context: "unsupported version".into(),
source: None, source: None,
} });
.fail();
} }
let mut meta = col.sync_meta().or_internal_err("sync meta")?; let mut meta = col.sync_meta().or_internal_err("sync meta")?;
if meta.v2_scheduler_or_later && req.sync_version < SYNC_VERSION_09_V2_SCHEDULER { if meta.v2_scheduler_or_later && req.sync_version < SYNC_VERSION_09_V2_SCHEDULER {

View File

@ -1,28 +1,37 @@
// Copyright: Ankitects Pty Ltd and contributors // Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html // License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::error::Error;
use std::fmt::Display;
use std::fmt::Formatter;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::response::Redirect; use axum::response::Redirect;
use axum::response::Response; use axum::response::Response;
use snafu::OptionExt;
use snafu::Snafu;
pub type HttpResult<T, E = HttpError> = std::result::Result<T, E>; pub type HttpResult<T, E = HttpError> = Result<T, E>;
#[derive(Debug, Snafu)] #[derive(Debug)]
#[snafu(visibility(pub))]
pub struct HttpError { pub struct HttpError {
pub code: StatusCode, pub code: StatusCode,
pub context: String, pub context: String,
// snafu's automatic error conversion only supports Option if pub source: Option<Box<dyn Error + Send + Sync>>,
// the whatever trait is derived, and deriving whatever means we }
// can't have extra fields like `code`. Even without Option, the
// error conversion requires us to manually box the error, so we end impl Display for HttpError {
// up having to disable the default behaviour and add the error to the fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
// snafu ourselves write!(f, "{} (code={})", self.context, self.code.as_u16())
#[snafu(source(false))] }
pub source: Option<Box<dyn std::error::Error + Send + Sync>>, }
impl Error for HttpError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match &self.source {
None => None,
Some(err) => Some(err.as_ref()),
}
}
} }
impl HttpError { impl HttpError {
@ -114,7 +123,7 @@ pub trait OrHttpErr {
impl<T, E> OrHttpErr for Result<T, E> impl<T, E> OrHttpErr for Result<T, E>
where where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, E: Into<Box<dyn Error + Send + Sync + 'static>>,
{ {
type Value = T; type Value = T;
@ -123,13 +132,10 @@ where
code: StatusCode, code: StatusCode,
context: impl Into<String>, context: impl Into<String>,
) -> Result<Self::Value, HttpError> { ) -> Result<Self::Value, HttpError> {
self.map_err(|err| { self.map_err(|err| HttpError {
HttpSnafu { code,
code, context: context.into(),
context: context.into(), source: Some(err.into()),
source: err.into(),
}
.build()
}) })
} }
} }
@ -142,9 +148,9 @@ impl<T> OrHttpErr for Option<T> {
code: StatusCode, code: StatusCode,
context: impl Into<String>, context: impl Into<String>,
) -> Result<Self::Value, HttpError> { ) -> Result<Self::Value, HttpError> {
self.context(HttpSnafu { self.ok_or_else(|| HttpError {
code, code,
context, context: context.into(),
source: None, source: None,
}) })
} }

View File

@ -27,7 +27,6 @@ use tokio_util::io::StreamReader;
use crate::error::Result; use crate::error::Result;
use crate::sync::error::HttpError; use crate::sync::error::HttpError;
use crate::sync::error::HttpResult; use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr; use crate::sync::error::OrHttpErr;
use crate::sync::request::header_and_stream::decode_zstd_body_stream_for_client; use crate::sync::request::header_and_stream::decode_zstd_body_stream_for_client;
use crate::sync::request::header_and_stream::encode_zstd_body_stream; use crate::sync::request::header_and_stream::encode_zstd_body_stream;
@ -149,11 +148,11 @@ impl IoMonitor {
data = response_body_stream => Ok(data?), data = response_body_stream => Ok(data?),
// timeout // timeout
_ = self.timeout(stall_duration) => { _ = self.timeout(stall_duration) => {
HttpSnafu { Err(HttpError {
code: StatusCode::REQUEST_TIMEOUT, code: StatusCode::REQUEST_TIMEOUT,
context: "timeout monitor", context: "timeout monitor".into(),
source: None, source: None,
}.fail() })
} }
} }
} }

View File

@ -18,7 +18,6 @@ use crate::sync::collection::progress::FullSyncProgressFn;
use crate::sync::collection::protocol::AsSyncEndpoint; use crate::sync::collection::protocol::AsSyncEndpoint;
use crate::sync::error::HttpError; use crate::sync::error::HttpError;
use crate::sync::error::HttpResult; use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::http_client::io_monitor::IoMonitor; use crate::sync::http_client::io_monitor::IoMonitor;
use crate::sync::login::SyncAuth; use crate::sync::login::SyncAuth;
use crate::sync::request::header_and_stream::SyncHeader; use crate::sync::request::header_and_stream::SyncHeader;
@ -113,13 +112,12 @@ impl HttpSyncClient {
impl From<Error> for HttpError { impl From<Error> for HttpError {
fn from(err: Error) -> Self { fn from(err: Error) -> Self {
HttpSnafu { HttpError {
// we should perhaps make this Optional instead // we should perhaps make this Optional instead
code: err.status().unwrap_or(StatusCode::SEE_OTHER), code: err.status().unwrap_or(StatusCode::SEE_OTHER),
context: "from reqwest", context: "from reqwest".into(),
source: Some(Box::new(err) as _), source: Some(Box::new(err) as _),
} }
.build()
} }
} }

View File

@ -21,8 +21,8 @@ use serde_derive::Serialize;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio_util::io::ReaderStream; use tokio_util::io::ReaderStream;
use crate::sync::error::HttpError;
use crate::sync::error::HttpResult; use crate::sync::error::HttpResult;
use crate::sync::error::HttpSnafu;
use crate::sync::error::OrHttpErr; use crate::sync::error::OrHttpErr;
use crate::sync::request::SyncRequest; use crate::sync::request::SyncRequest;
use crate::sync::request::MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED; use crate::sync::request::MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED;
@ -81,25 +81,19 @@ where
data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))), data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))),
); );
let reader = async_compression::tokio::bufread::ZstdDecoder::new(reader); let reader = async_compression::tokio::bufread::ZstdDecoder::new(reader);
ReaderStream::new(reader).map_err(|err| { ReaderStream::new(reader).map_err(|err| HttpError {
HttpSnafu { code: StatusCode::BAD_REQUEST,
code: StatusCode::BAD_REQUEST, context: "decode zstd body".into(),
context: "decode zstd body", source: Some(Box::new(err) as _),
source: Some(Box::new(err) as _),
}
.build()
}) })
} }
pub fn encode_zstd_body(data: Vec<u8>) -> impl Stream<Item = HttpResult<Bytes>> + Unpin { pub fn encode_zstd_body(data: Vec<u8>) -> impl Stream<Item = HttpResult<Bytes>> + Unpin {
let enc = async_compression::tokio::bufread::ZstdEncoder::new(Cursor::new(data)); let enc = async_compression::tokio::bufread::ZstdEncoder::new(Cursor::new(data));
ReaderStream::new(enc).map_err(|err| { ReaderStream::new(enc).map_err(|err| HttpError {
HttpSnafu { code: StatusCode::INTERNAL_SERVER_ERROR,
code: StatusCode::INTERNAL_SERVER_ERROR, context: "encode zstd body".into(),
context: "encode zstd body", source: Some(Box::new(err) as _),
source: Some(Box::new(err) as _),
}
.build()
}) })
} }
@ -112,13 +106,10 @@ where
data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))), data.map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))),
); );
let reader = async_compression::tokio::bufread::ZstdEncoder::new(reader); let reader = async_compression::tokio::bufread::ZstdEncoder::new(reader);
ReaderStream::new(reader).map_err(|err| { ReaderStream::new(reader).map_err(|err| HttpError {
HttpSnafu { code: StatusCode::BAD_REQUEST,
code: StatusCode::BAD_REQUEST, context: "encode zstd body".into(),
context: "encode zstd body", source: Some(Box::new(err) as _),
source: Some(Box::new(err) as _),
}
.build()
}) })
} }