Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/core/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ mod tests {
identity: &IDENTITY,
operation: &OPERATION,
bucket_config: Some(Cow::Borrowed(&*BUCKET_CONFIG)),
headers: &*HEADERS,
headers: &HEADERS,
source_ip: None,
request_id: "test-request-id",
list_rewrite: None,
Expand Down
2 changes: 2 additions & 0 deletions crates/oidc-provider/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ rsa.workspace = true
sha2.workspace = true
tracing.workspace = true
uuid.workspace = true
# Per-key async lock for single-flight credential refresh (see `cache.rs`).
futures.workspace = true

[dev-dependencies]
tokio = { workspace = true, features = ["rt", "macros"] }
Expand Down
222 changes: 176 additions & 46 deletions crates/oidc-provider/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,97 +1,227 @@
//! TTL credential cache.
//! Credential cache with single-flight refresh.
//!
//! Caches [`BackendCredentials`] by key, evicting entries that are within a
//! safety margin of expiration. This avoids redundant STS calls when the
//! same backend is accessed repeatedly within a short window.
//! Caches [`BackendCredentials`] by key so the proxy doesn't re-mint and
//! re-exchange on every request. Beyond a plain TTL cache it:
//!
//! - **serves while fresh** — returns a cached value directly while it is
//! comfortably valid,
//! - **proactively refreshes** — once a value is within [`REFRESH_LEAD_SECS`]
//! of expiry, the next access re-mints it, so a credential is never handed
//! out about to expire mid-request, and
//! - **single-flights** — while one caller is minting for a key, concurrent
//! callers for that *same* key await the in-flight result instead of each
//! launching their own exchange. A cold-cache burst collapses to one STS call.
//!
//! The fetch happens through a caller-supplied closure ([`get_or_fetch`]), so
//! the cache never needs to know how credentials are minted, and a runtime can
//! layer an additional cache tier (e.g. the Cloudflare Cache API) inside the
//! closure. See `docs/architecture/caching.md`.

use std::collections::HashMap;
use std::future::Future;
use std::sync::{Arc, Mutex};

use chrono::{Duration, Utc};
use futures::lock::Mutex as AsyncMutex;

use crate::BackendCredentials;

/// Safety margin before expiration — credentials are considered expired
/// this many seconds before their actual `expires_at`.
const EXPIRY_MARGIN_SECS: i64 = 60;
/// Refresh a cached credential once it is within this many seconds of expiry,
/// so it is never handed out about to expire mid-request.
const REFRESH_LEAD_SECS: i64 = 60;

/// One async-locked slot per key. The per-key [`AsyncMutex`] is what serializes
/// (single-flights) refreshes; the value is shared via `Arc`.
type Slot = Arc<AsyncMutex<Option<Arc<BackendCredentials>>>>;

/// Thread-safe TTL cache for cloud credentials.
/// Thread-safe credential cache with proactive refresh and single-flight.
///
/// `Clone` shares the same underlying store (the entries map is behind an
/// `Arc`), so a cloned [`OidcCredentialProvider`](crate::OidcCredentialProvider)
/// keeps hitting the same cache — letting a runtime hold the provider in a
/// `Clone` shares the same underlying store (the slot map is behind an `Arc`),
/// so a cloned [`OidcCredentialProvider`](crate::OidcCredentialProvider) keeps
/// hitting the same cache — letting a runtime hold the provider in a
/// shared/`static` slot and reuse it across requests instead of re-minting and
/// re-exchanging every time.
#[derive(Clone, Default)]
pub struct CredentialCache {
entries: Arc<Mutex<HashMap<String, Arc<BackendCredentials>>>>,
/// One slot per key. The outer `Mutex` only guards insertion into the map
/// and is never held across an `.await`; the per-key [`AsyncMutex`] inside
/// each [`Slot`] is what single-flights refreshes.
slots: Arc<Mutex<HashMap<String, Slot>>>,
}

impl CredentialCache {
/// Create an empty credential cache.
pub fn new() -> Self {
Self {
entries: Arc::new(Mutex::new(HashMap::new())),
slots: Arc::new(Mutex::new(HashMap::new())),
}
}

/// Retrieve cached credentials if they are still valid.
pub fn get(&self, key: &str) -> Option<Arc<BackendCredentials>> {
let entries = self.entries.lock().unwrap();
if let Some(creds) = entries.get(key) {
let margin = Duration::seconds(EXPIRY_MARGIN_SECS);
if creds.expiration > Utc::now() + margin {
return Some(creds.clone());
/// Return cached credentials for `key` if still fresh, otherwise run `fetch`
/// (single-flighted) to obtain and cache new ones.
///
/// A cached value is fresh while `now < expiration - REFRESH_LEAD_SECS`.
///
/// Single-flight: while one caller is running `fetch` for a key, concurrent
/// callers for that same key block on the per-key lock; when it releases
/// they observe the freshly-cached value and return it without calling their
/// own `fetch`.
pub async fn get_or_fetch<F, Fut, E>(
&self,
key: &str,
fetch: F,
) -> Result<Arc<BackendCredentials>, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<Arc<BackendCredentials>, E>>,
{
let slot = self.slot(key);
let mut guard = slot.lock().await;

if let Some(creds) = guard.as_ref() {
if is_fresh(creds) {
return Ok(creds.clone());
}
}
None

let fresh = fetch().await?;
*guard = Some(fresh.clone());
Ok(fresh)
}

/// Store credentials in the cache.
pub fn put(&self, key: String, creds: Arc<BackendCredentials>) {
let mut entries = self.entries.lock().unwrap();
entries.insert(key, creds);
fn slot(&self, key: &str) -> Slot {
self.slots
.lock()
.expect("credential cache mutex poisoned")
.entry(key.to_string())
.or_insert_with(|| Arc::new(AsyncMutex::new(None)))
.clone()
}
}

/// A credential is fresh while it is more than [`REFRESH_LEAD_SECS`] from expiry.
fn is_fresh(creds: &BackendCredentials) -> bool {
creds.expiration > Utc::now() + Duration::seconds(REFRESH_LEAD_SECS)
}

#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};

fn make_creds(expires_in_secs: i64) -> BackendCredentials {
BackendCredentials {
fn creds(expires_in_secs: i64) -> Arc<BackendCredentials> {
Arc::new(BackendCredentials {
access_key_id: "AKID".into(),
secret_access_key: "secret".into(),
session_token: "token".into(),
expiration: Utc::now() + Duration::seconds(expires_in_secs),
}
})
}

#[test]
fn cache_returns_valid_entry() {
#[tokio::test]
async fn fetches_on_miss() {
let cache = CredentialCache::new();
let creds = Arc::new(make_creds(600));
cache.put("role-a".into(), creds.clone());

let got = cache.get("role-a");
assert!(got.is_some());
assert_eq!(got.unwrap().access_key_id, "AKID");
let got = cache
.get_or_fetch("k", || async { Ok::<_, ()>(creds(600)) })
.await
.unwrap();
assert_eq!(got.access_key_id, "AKID");
}

#[test]
fn cache_evicts_expired_entry() {
#[tokio::test]
async fn reuses_while_fresh() {
let cache = CredentialCache::new();
// Expires in 30 seconds — within the 60-second margin
let creds = Arc::new(make_creds(30));
cache.put("role-b".into(), creds);
cache
.get_or_fetch("k", || async { Ok::<_, ()>(creds(600)) })
.await
.unwrap();
// Well outside the 60s refresh lead → must not re-fetch.
let got = cache
.get_or_fetch::<_, _, ()>("k", || async {
panic!("must not fetch while cached creds are fresh")
})
.await
.unwrap();
assert_eq!(got.access_key_id, "AKID");
}

let got = cache.get("role-b");
assert!(got.is_none());
#[tokio::test]
async fn refreshes_within_lead_window() {
let cache = CredentialCache::new();
// Expires in 30s — inside the 60s refresh lead → due for refresh.
cache
.get_or_fetch("k", || async { Ok::<_, ()>(creds(30)) })
.await
.unwrap();
let got = cache
.get_or_fetch("k", || async {
Ok::<_, ()>(Arc::new(BackendCredentials {
access_key_id: "REFRESHED".into(),
secret_access_key: "secret".into(),
session_token: "token".into(),
expiration: Utc::now() + Duration::hours(1),
}))
})
.await
.unwrap();
assert_eq!(got.access_key_id, "REFRESHED");
}

#[test]
fn cache_miss_for_unknown_key() {
#[tokio::test]
async fn keys_are_isolated() {
let cache = CredentialCache::new();
assert!(cache.get("unknown").is_none());
cache
.get_or_fetch("a", || async { Ok::<_, ()>(creds(600)) })
.await
.unwrap();
// A different key is a miss → fetches.
let mut fetched = false;
cache
.get_or_fetch("b", || async {
fetched = true;
Ok::<_, ()>(creds(600))
})
.await
.unwrap();
assert!(fetched);
}

#[tokio::test]
async fn single_flights_concurrent_fetches() {
let cache = Arc::new(CredentialCache::new());
let calls = Arc::new(AtomicUsize::new(0));

let one = {
let cache = cache.clone();
let calls = calls.clone();
async move {
cache
.get_or_fetch("k", || async {
calls.fetch_add(1, Ordering::SeqCst);
// Yield while holding the per-key lock so the sibling
// future contends for it — exercising single-flight.
tokio::task::yield_now().await;
Ok::<_, ()>(creds(600))
})
.await
}
};
let two = {
let cache = cache.clone();
let calls = calls.clone();
async move {
cache
.get_or_fetch("k", || async {
calls.fetch_add(1, Ordering::SeqCst);
Ok::<_, ()>(creds(600))
})
.await
}
};

let (a, b) = tokio::join!(one, two);
a.unwrap();
b.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1, "fetch should run once");
}
}
2 changes: 1 addition & 1 deletion crates/oidc-provider/src/jwks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ mod tests {
assert_eq!(key["use"], "sig");
assert_eq!(key["kid"], "my-kid");
assert!(key["n"].as_str().unwrap().len() > 10);
assert!(key["e"].as_str().unwrap().len() > 0);
assert!(!key["e"].as_str().unwrap().is_empty());
}

#[test]
Expand Down
36 changes: 16 additions & 20 deletions crates/oidc-provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,33 +89,29 @@ impl<H: HttpExchange> OidcCredentialProvider<H> {
/// Get credentials for a backend, using cached values when available.
///
/// `exchange` describes how to trade the self-signed JWT for cloud
/// credentials (AWS, Azure, GCP). `cache_key` identifies the backend
/// for caching purposes (e.g. the role ARN).
/// credentials (AWS, Azure, GCP). `cache_key` identifies the backend for
/// caching purposes (e.g. the role ARN).
///
/// Concurrent calls for the same `cache_key` are single-flighted: only one
/// JWT mint + exchange runs, and the rest await its result. A cached value
/// is reused until it nears expiry, then proactively re-minted.
pub async fn get_credentials<E: CredentialExchange<H>>(
&self,
cache_key: &str,
exchange: &E,
subject: &str,
extra_claims: &[(&str, &str)],
) -> Result<Arc<BackendCredentials>, OidcProviderError> {
// Check cache first
if let Some(creds) = self.cache.get(cache_key) {
return Ok(creds);
}

// Mint a JWT
let token = self
.signer
.sign(subject, &self.issuer, &self.audience, extra_claims)?;

// Exchange it for cloud credentials
let creds: BackendCredentials = exchange.exchange(&self.http, &token).await?;
let creds = Arc::new(creds);

// Cache
self.cache.put(cache_key.to_string(), creds.clone());

Ok(creds)
self.cache
.get_or_fetch(cache_key, || async {
// Cache miss (or due for refresh): mint a JWT and exchange it.
let token =
self.signer
.sign(subject, &self.issuer, &self.audience, extra_claims)?;
let creds: BackendCredentials = exchange.exchange(&self.http, &token).await?;
Ok(Arc::new(creds))
})
.await
}

/// Access the underlying signer (e.g. for JWKS generation).
Expand Down
4 changes: 4 additions & 0 deletions docs/.vitepress/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ const adminSidebar = [
text: "Multi-Runtime Design",
link: "/architecture/multi-runtime",
},
{
text: "Caching",
link: "/architecture/caching",
},
],
},
{
Expand Down
Loading
Loading