Skip to content

Commit 40eb0eb

Browse files
committed
used shared helper for setting up Edge connections
1 parent 0641a89 commit 40eb0eb

6 files changed

Lines changed: 116 additions & 101 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/defguard_core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ defguard_web_ui = { workspace = true }
1616
defguard_version = { workspace = true }
1717
model_derive = { workspace = true }
1818
defguard_certs = { workspace = true }
19+
defguard_grpc_tls = { workspace = true }
1920
defguard_static_ip = { workspace = true }
2021

2122
# external dependencies
@@ -30,6 +31,7 @@ bytes = { workspace = true }
3031
chrono = { workspace = true }
3132
futures = { workspace = true }
3233
humantime = { workspace = true }
34+
hyper-rustls = { workspace = true }
3335
# match version used by sqlx
3436
ipnetwork = { workspace = true }
3537
jsonwebkey = { workspace = true }

crates/defguard_core/src/handlers/component_setup.rs

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
collections::VecDeque,
2+
collections::{HashMap, VecDeque},
33
convert::Infallible,
44
sync::{Arc, Mutex, PoisonError},
55
time::Duration,
@@ -29,6 +29,7 @@ use defguard_common::{
2929
types::proxy::ProxyControlMessage,
3030
utils::strip_scheme,
3131
};
32+
use defguard_grpc_tls::certs::proxy_mtls_channel;
3233
use defguard_proto::{
3334
common::{CertBundle, CertificateInfo},
3435
gateway::gateway_setup_client::GatewaySetupClient,
@@ -43,7 +44,10 @@ use reqwest::Url;
4344
use serde::{Deserialize, Serialize};
4445
use sqlx::PgPool;
4546
use tokio::{
46-
sync::mpsc::{Sender, UnboundedReceiver, UnboundedSender, unbounded_channel},
47+
sync::{
48+
mpsc::{Sender, UnboundedReceiver, UnboundedSender, unbounded_channel},
49+
oneshot, watch,
50+
},
4751
time::{Instant, sleep_until, timeout},
4852
};
4953
use tokio_stream::StreamExt;
@@ -1175,8 +1179,7 @@ fn public_proxy_hostname() -> Result<String, String> {
11751179
/// collected during the ACME run (sent by the proxy via an [`AcmeLogs`] event).
11761180
async fn call_proxy_trigger_acme(
11771181
pool: &PgPool,
1178-
proxy_host: &str,
1179-
proxy_port: u16,
1182+
proxy: &Proxy<Id>,
11801183
domain: String,
11811184
account_credentials_json: String,
11821185
progress_tx: UnboundedSender<AcmeStep>,
@@ -1191,32 +1194,29 @@ async fn call_proxy_trigger_acme(
11911194
)
11921195
})?;
11931196

1194-
let cert_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate)
1195-
.map_err(|e| (format!("Failed to convert CA cert to PEM: {e}"), Vec::new()))?;
1196-
1197-
let endpoint_str = format!("https://{proxy_host}:{proxy_port}");
1198-
let endpoint = Endpoint::from_shared(endpoint_str)
1199-
.map_err(|e| (format!("Failed to build Edge endpoint: {e}"), Vec::new()))?
1200-
.http2_keep_alive_interval(Duration::from_secs(5))
1201-
.tcp_keepalive(Some(Duration::from_secs(5)))
1202-
.keep_alive_while_idle(true);
1203-
1204-
let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(cert_pem));
1205-
let endpoint = endpoint.tls_config(tls).map_err(|e| {
1197+
let cert_serial = proxy.certificate_serial.as_deref().ok_or_else(|| {
12061198
(
1207-
format!("Failed to configure TLS for Edge endpoint: {e}"),
1199+
"Edge certificate serial not provisioned".to_string(),
12081200
Vec::new(),
12091201
)
12101202
})?;
12111203

1204+
// Seed a one-shot serial map so the rustls verifier validates the server cert serial.
1205+
let (_, certs_rx) = watch::channel(Arc::new(HashMap::from([(
1206+
proxy.id,
1207+
cert_serial.to_string(),
1208+
)])));
1209+
1210+
let channel = proxy_mtls_channel(proxy, &ca_cert_der, certs_rx)
1211+
.map_err(|e| (format!("Failed to build mTLS channel: {e}"), Vec::new()))?;
1212+
12121213
let version = Version::parse(VERSION)
12131214
.map_err(|e| (format!("Failed to parse core version: {e}"), Vec::new()))?;
12141215
let version_interceptor = ClientVersionInterceptor::new(version);
12151216

1216-
let mut client =
1217-
ProxyClient::with_interceptor(endpoint.connect_lazy(), move |req: Request<()>| {
1218-
version_interceptor.clone().call(req)
1219-
});
1217+
let mut client = ProxyClient::with_interceptor(channel, move |req: Request<()>| {
1218+
version_interceptor.clone().call(req)
1219+
});
12201220

12211221
let mut stream = client
12221222
.trigger_acme(AcmeChallenge {
@@ -1300,7 +1300,7 @@ pub async fn stream_proxy_acme(
13001300

13011301
let account_credentials_json = certs.acme_account_credentials.clone().unwrap_or_default();
13021302

1303-
let proxies = match Proxy::list(&pool).await {
1303+
let proxies = match Proxy::all_enabled(&pool).await {
13041304
Ok(list) => list,
13051305
Err(e) => {
13061306
yield Ok(acme_error_event(
@@ -1323,8 +1323,8 @@ pub async fn stream_proxy_acme(
13231323
return;
13241324
};
13251325

1326-
let proxy_host = proxy.address.clone();
1327-
let proxy_port = proxy.port as u16;
1326+
let proxy_host = &proxy.address;
1327+
let proxy_port = proxy.port;
13281328
info!(
13291329
"Triggering ACME HTTP-01 via Edge gRPC TriggerAcme for domain: {domain} \
13301330
Edge={proxy_host}:{proxy_port}"
@@ -1333,16 +1333,15 @@ pub async fn stream_proxy_acme(
13331333
let (progress_tx, mut progress_rx) =
13341334
unbounded_channel::<AcmeStep>();
13351335
let (result_tx, result_rx) =
1336-
tokio::sync::oneshot::channel::<Result<(String, String, String), (String, Vec<String>)>>();
1336+
oneshot::channel::<Result<(String, String, String), (String, Vec<String>)>>();
13371337

13381338
let pool_clone = pool.clone();
13391339
let domain_clone = domain.clone();
13401340
let acct_creds_clone = account_credentials_json.clone();
13411341
tokio::spawn(async move {
13421342
let result = call_proxy_trigger_acme(
13431343
&pool_clone,
1344-
&proxy_host,
1345-
proxy_port,
1344+
&proxy,
13461345
domain_clone,
13471346
acct_creds_clone,
13481347
progress_tx,

crates/defguard_grpc_tls/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ rust-version.workspace = true
1010
[dependencies]
1111
defguard_common.workspace = true
1212
http = "1.1"
13+
hyper-rustls.workspace = true
1314
rustls = { version = "0.23", features = ["ring"] }
1415
thiserror.workspace = true
1516
tokio.workspace = true

crates/defguard_grpc_tls/src/certs.rs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
//! - A lightweight in-memory cache (refreshed periodically) avoids database access
99
//! during the handshake and keeps verification synchronous.
1010
11-
use std::{collections::HashMap, sync::Arc};
11+
use std::{collections::HashMap, sync::Arc, time::Duration};
1212

13-
use defguard_common::db::Id;
13+
use defguard_common::db::{Id, models::proxy::Proxy};
14+
use hyper_rustls::HttpsConnectorBuilder;
1415
use rustls::{
1516
CertificateError, DistinguishedName, Error as RustlsError, RootCertStore, SignatureScheme,
1617
client::{
@@ -22,10 +23,14 @@ use rustls::{
2223
};
2324
use thiserror::Error;
2425
use tokio::sync::watch;
25-
use tonic::transport::{Certificate, Identity, ServerTlsConfig};
26+
use tonic::transport::{Certificate, Channel, Endpoint, Identity, ServerTlsConfig};
2627
use tracing::error;
2728
use x509_parser::parse_x509_certificate;
2829

30+
use crate::connector::HttpsSchemeConnector;
31+
32+
const TEN_SECS: Duration = Duration::from_secs(10);
33+
2934
/// Errors that can occur while building a TLS config with a pinned verifier.
3035
#[derive(Debug, Error)]
3136
pub enum CertConfigError {
@@ -208,3 +213,53 @@ pub fn client_config(
208213
)));
209214
Ok(config)
210215
}
216+
217+
/// Build an mTLS [`Channel`] to a proxy using its stored per-component client certificate.
218+
///
219+
/// * `proxy` — the full `Proxy<Id>` row from the database; `core_client_cert_der`,
220+
/// `core_client_cert_key_der`, and `certificate_serial` must all be `Some`.
221+
/// * `ca_cert_der` — the core CA certificate in DER form, used as the only trusted root.
222+
/// * `certs_rx` — watch channel carrying the current `{ proxy_id → cert_serial }` map.
223+
/// Pass a long-lived receiver for persistent connections (serial revocation is picked up
224+
/// dynamically) or a one-shot channel seeded with the proxy's current serial for
225+
/// short-lived calls.
226+
///
227+
/// The returned channel uses an `http://` endpoint scheme; TLS is applied by the
228+
/// internal [`HttpsSchemeConnector`](crate::connector::HttpsSchemeConnector).
229+
pub fn proxy_mtls_channel(
230+
proxy: &Proxy<Id>,
231+
ca_cert_der: &[u8],
232+
certs_rx: watch::Receiver<Arc<HashMap<Id, String>>>,
233+
) -> Result<Channel, CertConfigError> {
234+
let cert_der = proxy.core_client_cert_der.as_deref().ok_or_else(|| {
235+
CertConfigError::TlsConfig(format!(
236+
"core client certificate not provisioned for proxy id={}",
237+
proxy.id
238+
))
239+
})?;
240+
let key_der = proxy.core_client_cert_key_der.as_deref().ok_or_else(|| {
241+
CertConfigError::TlsConfig(format!(
242+
"core client certificate key not provisioned for proxy id={}",
243+
proxy.id
244+
))
245+
})?;
246+
247+
let tls_config = client_config(ca_cert_der, certs_rx, proxy.id, cert_der, key_der)?;
248+
249+
let connector = HttpsConnectorBuilder::new()
250+
.with_tls_config(tls_config)
251+
.https_only()
252+
.enable_http2()
253+
.build();
254+
let connector = HttpsSchemeConnector::new(connector);
255+
256+
// Use http:// scheme — the HttpsSchemeConnector rewrites it to https:// internally.
257+
let endpoint_str = format!("http://{}:{}", proxy.address, proxy.port);
258+
let endpoint = Endpoint::from_shared(endpoint_str)
259+
.map_err(|e| CertConfigError::TlsConfig(format!("invalid proxy endpoint URL: {e}")))?
260+
.http2_keep_alive_interval(TEN_SECS)
261+
.tcp_keepalive(Some(TEN_SECS))
262+
.keep_alive_while_idle(true);
263+
264+
Ok(endpoint.connect_with_connector_lazy(connector))
265+
}

0 commit comments

Comments
 (0)