Add support for custom certificates (#3203)

* Add support for custom certificates

* Update lints

* Update licenses

* Changes after feedback

* More changes
This commit is contained in:
Voczi 2024-05-24 11:57:54 +02:00 committed by GitHub
parent 1957566c39
commit 9e3a34f17f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 101 additions and 12 deletions

View File

@ -174,6 +174,7 @@ Escape0707 <tothesong@gmail.com>
Loudwig <https://github.com/Loudwig>
Wu Yi-Wei <https://github.com/Ianwu0812>
********************
The text of the 3 clause BSD license follows:

23
Cargo.lock generated
View File

@ -144,6 +144,7 @@ dependencies = [
"regex",
"reqwest",
"rusqlite",
"rustls-pemfile 2.1.2",
"scopeguard",
"serde",
"serde-aux",
@ -511,6 +512,12 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.6.0"
@ -4605,7 +4612,7 @@ dependencies = [
"pin-project-lite",
"rustls 0.21.11",
"rustls-native-certs",
"rustls-pemfile",
"rustls-pemfile 1.0.4",
"serde",
"serde_json",
"serde_urlencoded",
@ -4786,7 +4793,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pemfile 1.0.4",
"schannel",
"security-framework",
]
@ -4800,6 +4807,16 @@ dependencies = [
"base64 0.21.7",
]
[[package]]
name = "rustls-pemfile"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
dependencies = [
"base64 0.22.1",
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.3.1"
@ -6265,7 +6282,7 @@ dependencies = [
"mime_guess",
"percent-encoding",
"pin-project",
"rustls-pemfile",
"rustls-pemfile 1.0.4",
"scoped-tls",
"serde",
"serde_json",

View File

@ -111,6 +111,7 @@ rand = "0.8.5"
regex = "1.10.3"
reqwest = { version = "0.11.24", default-features = false, features = ["json", "socks", "stream", "multipart"] }
rusqlite = { version = "0.30.0", features = ["trace", "functions", "collation", "bundled"] }
rustls-pemfile = "2.1.2"
scopeguard = "1.2.0"
serde = { version = "1.0.197", features = ["derive"] }
serde-aux = "4.5.0"

View File

@ -287,6 +287,15 @@
"license_file": null,
"description": "encodes and decodes base64 as bytes or utf8"
},
{
"name": "base64",
"version": "0.22.1",
"authors": "Marshall Pierce <marshall@mpierce.org>",
"repository": "https://github.com/marshallpierce/rust-base64",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "encodes and decodes base64 as bytes or utf8"
},
{
"name": "base64ct",
"version": "1.6.0",
@ -3257,6 +3266,15 @@
"license_file": null,
"description": "Basic .pem file parser for keys and certificates"
},
{
"name": "rustls-pemfile",
"version": "2.1.2",
"authors": null,
"repository": "https://github.com/rustls/pemfile",
"license": "Apache-2.0 OR ISC OR MIT",
"license_file": null,
"description": "Basic .pem file parser for keys and certificates"
},
{
"name": "rustls-pki-types",
"version": "1.3.1",

View File

@ -45,6 +45,7 @@ message BackendError {
// Originated from and usually specific to the OS.
OS_ERROR = 20;
SCHEDULER_UPGRADE_REQUIRED = 21;
INVALID_CERTIFICATE_FORMAT = 22;
}
// error description, usually localized, suitable for displaying to the user

View File

@ -23,6 +23,7 @@ service BackendSyncService {
rpc SyncCollection(SyncCollectionRequest) returns (SyncCollectionResponse);
rpc FullUploadOrDownload(FullUploadOrDownloadRequest) returns (generic.Empty);
rpc AbortSync(generic.Empty) returns (generic.Empty);
rpc SetCustomCertificate(generic.String) returns (generic.Bool);
}
message SyncAuth {

View File

@ -83,6 +83,7 @@ rand.workspace = true
regex.workspace = true
reqwest.workspace = true
rusqlite.workspace = true
rustls-pemfile.workspace = true
scopeguard.workspace = true
serde.workspace = true
serde-aux.workspace = true

View File

@ -49,6 +49,7 @@ impl AnkiError {
AnkiError::WindowsError { .. } => Kind::OsError,
AnkiError::SchedulerUpgradeRequired => Kind::SchedulerUpgradeRequired,
AnkiError::FsrsInsufficientReviews { .. } => Kind::InvalidInput,
AnkiError::InvalidCertificateFormat => Kind::InvalidCertificateFormat,
};
anki_proto::backend::BackendError {

View File

@ -56,7 +56,7 @@ pub struct BackendInner {
state: Mutex<BackendState>,
backup_task: Mutex<Option<JoinHandle<Result<()>>>>,
media_sync_task: Mutex<Option<JoinHandle<Result<()>>>>,
web_client: OnceCell<Client>,
web_client: Mutex<Option<Client>>,
}
#[derive(Default)]
@ -91,7 +91,7 @@ impl Backend {
state: Mutex::new(BackendState::default()),
backup_task: Mutex::new(None),
media_sync_task: Mutex::new(None),
web_client: OnceCell::new(),
web_client: Mutex::new(None),
}))
}
@ -137,10 +137,46 @@ impl Backend {
.clone()
}
fn web_client(&self) -> &Client {
#[cfg(feature = "rustls")]
fn set_custom_certificate_inner(&self, cert_str: String) -> Result<()> {
use std::io::Cursor;
use std::io::Read;
use reqwest::Certificate;
let mut web_client = self.web_client.lock().unwrap();
if cert_str.is_empty() {
let _ = web_client.insert(Client::builder().http1_only().build().unwrap());
return Ok(());
}
if rustls_pemfile::read_all(Cursor::new(cert_str.as_bytes()).by_ref()).count() != 1 {
return Err(AnkiError::InvalidCertificateFormat);
}
if let Ok(certificate) = Certificate::from_pem(cert_str.as_bytes()) {
if let Ok(new_client) = Client::builder()
.use_rustls_tls()
.add_root_certificate(certificate)
.http1_only()
.build()
{
let _ = web_client.insert(new_client);
return Ok(());
}
}
Err(AnkiError::InvalidCertificateFormat)
}
fn web_client(&self) -> Client {
// currently limited to http1, as nginx doesn't support http2 proxies
self.web_client
.get_or_init(|| Client::builder().http1_only().build().unwrap())
let mut web_client = self.web_client.lock().unwrap();
return web_client
.get_or_insert_with(|| Client::builder().http1_only().build().unwrap())
.clone();
}
fn db_command(&self, input: &[u8]) -> Result<Vec<u8>> {

View File

@ -154,6 +154,16 @@ impl crate::services::BackendSyncService for Backend {
)?;
Ok(())
}
fn set_custom_certificate(
&self,
_input: anki_proto::generic::String,
) -> Result<anki_proto::generic::Bool> {
#[cfg(feature = "rustls")]
return Ok(self.set_custom_certificate_inner(_input.val).is_ok().into());
#[cfg(not(feature = "rustls"))]
return Ok(false.into());
}
}
impl Backend {
@ -284,7 +294,7 @@ impl Backend {
input.username,
input.password,
input.endpoint.clone(),
self.web_client().clone(),
self.web_client(),
);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
let ret = match rt.block_on(abortable_sync) {
@ -323,7 +333,7 @@ impl Backend {
let rt = self.runtime_handle();
let time_at_check_begin = TimestampSecs::now();
let local = self.with_col(|col| col.sync_meta())?;
let mut client = HttpSyncClient::new(auth, self.web_client().clone());
let mut client = HttpSyncClient::new(auth, self.web_client());
let state = rt.block_on(online_sync_status_check(local, &mut client))?;
{
let mut guard = self.state.lock().unwrap();
@ -348,7 +358,7 @@ impl Backend {
let (_guard, abort_reg) = self.sync_abort_handle()?;
let rt = self.runtime_handle();
let client = self.web_client().clone();
let client = self.web_client();
let auth2 = auth.clone();
let ret = self.with_col(|col| {

View File

@ -122,6 +122,7 @@ pub enum AnkiError {
},
FsrsUnableToDetermineDesiredRetention,
SchedulerUpgradeRequired,
InvalidCertificateFormat,
}
// error helpers
@ -169,7 +170,8 @@ impl AnkiError {
| AnkiError::Existing
| AnkiError::InvalidServiceIndex
| AnkiError::InvalidMethodIndex
| AnkiError::UndoEmpty => format!("{:?}", self),
| AnkiError::UndoEmpty
| AnkiError::InvalidCertificateFormat => format!("{:?}", self),
AnkiError::FileIoError { source } => source.message(),
AnkiError::InvalidInput { source } => source.message(),
AnkiError::NotFound { source } => source.message(tr),