diff --git a/Cargo.lock b/Cargo.lock index c64419dd..51078d23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2043,6 +2043,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "owo-colors" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" + [[package]] name = "p256" version = "0.13.2" @@ -2272,6 +2278,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "progenitor-client" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8a874cf25a33cac7a01b9c1de87bcfbc8aea93f3156d09dcc3bee516a78926" +dependencies = [ + "bytes", + "futures-core", + "percent-encoding", + "reqwest 0.13.2", + "serde", + "serde_json", + "serde_urlencoded", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2558,6 +2579,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "serde_json", + "serde_urlencoded", "sync_wrapper", "tokio", "tokio-rustls 0.26.4", @@ -2732,7 +2754,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3331,6 +3353,15 @@ dependencies = [ "libc", ] +[[package]] +name = "tabwriter" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fce91f2f0ec87dff7e6bcbbeb267439aa1188703003c6055193c821487400432" +dependencies = [ + "unicode-width", +] + [[package]] name = "take_mut" version = "0.2.2" @@ -3721,6 +3752,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -3858,6 +3895,30 @@ dependencies = [ "v-model", ] +[[package]] +name = "v-cli-sdk" +version = "0.2.0" +dependencies = [ + "anyhow", + "chrono", + "clap", + "http", + "http-body-util", + "hyper", + "hyper-util", + "oauth2", + "oauth2-reqwest", + "owo-colors", + "progenitor-client", + "reqwest 0.13.2", + "schemars 0.8.22", + "serde", + "serde_json", + "tabwriter", + "tokio", + "uuid", +] + [[package]] name = "v-model" version = "0.2.0" @@ -3878,6 +3939,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", + "url", "uuid", "v-api-installer", ] @@ -4118,7 +4180,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 7d58a2ae..2b476f4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "v-api-installer", "v-api-param", "v-api-permission-derive", + "v-cli-sdk", "v-model", "xtask" ] @@ -29,14 +30,17 @@ hex = "0.4.3" http = "1" http-body-util = "0.1.3" hyper = "1.9.0" +hyper-util = "0.1" jsonwebtoken = { version = "10.2", features = ["rust_crypto"] } mockall = "0.14.0" newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] } oauth2 = { version = "5.0.0", default-features = false, features = ["rustls-tls"] } oauth2-reqwest = "0.1.0-alpha.3" +owo-colors = "4.2.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } percent-encoding = "2.3.2" proc-macro2 = "1" +progenitor-client = "0.14.0" quote = "1" rand = "0.10.1" rand_core = "0.10.1" @@ -53,6 +57,7 @@ sha2 = "0.11.0" slog = "2.8.2" steno = { git = "https://github.com/oxidecomputer/steno" } syn = "2" +tabwriter = "1.4.1" tap = "1.0.1" tempfile = "3" thiserror = "2" diff --git a/v-api-permission-derive/src/lib.rs b/v-api-permission-derive/src/lib.rs index 62dfe6ca..0dc29db4 100644 --- a/v-api-permission-derive/src/lib.rs +++ b/v-api-permission-derive/src/lib.rs @@ -439,6 +439,7 @@ fn from_system_permission_tokens( VPermission::ManageMagicLinkClientsAll => Self::ManageMagicLinkClientsAll, VPermission::CreateAccessToken => Self::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken => Self::RetrieveRemoteAccessToken, VPermission::GetSagasAll => Self::GetSagasAll, VPermission::ManageSagasAll => Self::ManageSagasAll, @@ -722,6 +723,7 @@ fn system_permission_tokens() -> TokenStream { ManageMagicLinkClientsAll, CreateAccessToken, + RetrieveRemoteAccessToken, #[v_api(scope(to = "saga:r", from = "saga:r"))] GetSagasAll, diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 22789099..a12a688c 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -8,20 +8,23 @@ use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, Jwk, KeyAlgorithm, PublicKeyUse, RSAKeyParameters, RSAKeyType, }; +use newtype_uuid::TypedUuid; +use partial_struct::partial; use rsa::{ pkcs1v15::{SigningKey, VerifyingKey}, pkcs8::{DecodePrivateKey, DecodePublicKey}, traits::PublicKeyParts, RsaPrivateKey, RsaPublicKey, }; -use secrecy::ExposeSecret; +use secrecy::{ExposeSecret, SecretString}; use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; use std::path::PathBuf; use thiserror::Error; -use v_api_param::StringParam; +use v_api_param::{ParamResolutionError, StringParam}; +use v_model::OAuthClientId; use crate::{ authn::{ @@ -151,25 +154,99 @@ pub struct SendGridConfig { pub struct OAuthProviders { pub github: Option, pub google: Option, + pub zendesk: Option, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthConfig { - pub device: OAuthDeviceConfig, - pub web: OAuthWebConfig, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub device: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub web: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub proxy_web: Option, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthDeviceConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthDeviceConfig { - pub client_id: String, - pub client_secret: StringParam, + pub client_id: TypedUuid, + pub remote_client_id: String, + #[partial(ResolvedOAuthDeviceConfig(retype = SecretString))] + pub remote_client_secret: StringParam, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthWebConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthWebConfig { - pub client_id: String, - pub client_secret: StringParam, + pub remote_client_id: String, + #[partial(ResolvedOAuthWebConfig(retype = SecretString))] + pub remote_client_secret: StringParam, +} + +#[partial(ResolvedOAuthWebProxyConfig)] +#[derive(Clone, Debug, Deserialize)] +pub struct OAuthWebProxyConfig { + pub client_id: TypedUuid, pub redirect_uri: String, + pub proxy_port: u16, +} + +impl OAuthConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let device = self + .device + .as_ref() + .and_then(|d| d.resolve(base.clone()).ok()); + let web = self.web.as_ref().and_then(|w| w.resolve(base.clone()).ok()); + let proxy_web = self.proxy_web.as_ref().and_then(|p| p.resolve(base).ok()); + Ok(ResolvedOAuthConfig { + device, + web, + proxy_web, + }) + } +} +impl OAuthDeviceConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let remote_client_secret = self.remote_client_secret.resolve(base)?; + Ok(ResolvedOAuthDeviceConfig { + client_id: self.client_id, + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, + }) + } +} +impl OAuthWebConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let remote_client_secret = self.remote_client_secret.resolve(base)?; + Ok(ResolvedOAuthWebConfig { + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, + }) + } +} +impl OAuthWebProxyConfig { + pub fn resolve( + &self, + _base: Option, + ) -> Result { + Ok(ResolvedOAuthWebProxyConfig { + client_id: self.client_id, + redirect_uri: self.redirect_uri.clone(), + proxy_port: self.proxy_port, + }) + } } impl AsymmetricKey { diff --git a/v-api/src/context/auth.rs b/v-api/src/context/auth.rs index a818820b..dd8a5719 100644 --- a/v-api/src/context/auth.rs +++ b/v-api/src/context/auth.rs @@ -40,9 +40,22 @@ where jwks: JwkSet, signers: Vec, verifiers: Vec, + mut additional_permissions: Vec, ) -> Result { let signers = signers.into_iter().map(Arc::new).collect::>(); let verifiers = verifiers.into_iter().map(Arc::new).collect::>(); + additional_permissions.extend_from_slice(&[ + VPermission::CreateApiUser.into(), + VPermission::GetApiUsersAll.into(), + VPermission::ManageApiUsersAll.into(), + VPermission::GetApiKeysAll.into(), + VPermission::CreateGroup.into(), + VPermission::GetGroupsAll.into(), + VPermission::CreateMapper.into(), + VPermission::GetMappersAll.into(), + VPermission::GetOAuthClientsAll.into(), + VPermission::CreateAccessToken.into(), + ]); Ok(Self { unauthenticated_caller: Caller { id: "00000000-0000-4000-8000-000000000000".parse().unwrap(), @@ -51,19 +64,7 @@ where }, registration_caller: Caller { id: "00000000-0000-4000-8000-000000000001".parse().unwrap(), - permissions: vec![ - VPermission::CreateApiUser, - VPermission::GetApiUsersAll, - VPermission::ManageApiUsersAll, - VPermission::GetApiKeysAll, - VPermission::CreateGroup, - VPermission::GetGroupsAll, - VPermission::CreateMapper, - VPermission::GetMappersAll, - VPermission::GetOAuthClientsAll, - VPermission::CreateAccessToken, - ] - .into(), + permissions: additional_permissions.into(), extensions: HashMap::default(), }, jwt: JwtContext { @@ -215,6 +216,7 @@ mod tests { wrong_verifier.resolve_verifier(None).await.unwrap(), verifier.resolve_verifier(None).await.unwrap(), ], + vec![], ) .unwrap(); diff --git a/v-api/src/context/login.rs b/v-api/src/context/login.rs index c3ddcb49..75d07fa3 100644 --- a/v-api/src/context/login.rs +++ b/v-api/src/context/login.rs @@ -44,14 +44,12 @@ where attempt: LoginAttempt, code: String, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.provider_authz_code = Some(code); + let mut update: NewLoginAttempt = attempt.into(); + update.provider_authz_code = Some(code); + update.attempt_state = LoginAttemptState::RemoteAuthenticated; + update.authz_code = Some(CsrfToken::new_random().secret().to_string()); - // TODO: Internal state changes to the struct - attempt.attempt_state = LoginAttemptState::RemoteAuthenticated; - attempt.authz_code = Some(CsrfToken::new_random().secret().to_string()); - - LoginAttemptStore::upsert(&*self.storage, attempt).await + LoginAttemptStore::update_if_state(&*self.storage, update, LoginAttemptState::New).await } pub async fn get_login_attempt( @@ -84,25 +82,38 @@ where Ok(attempts.pop()) } - pub async fn complete_login_attempt( + /// Atomically claim a login attempt by transitioning it from `RemoteAuthenticated` + /// to `Complete`. Returns an error if the attempt has already been claimed by a + /// concurrent request (i.e., the state is no longer `RemoteAuthenticated`). + /// This must be called before exchanging the authorization code with the remote + /// provider to prevent the same code from being used twice (RFC 6749 §4.1.2). + pub async fn claim_login_attempt( &self, attempt: LoginAttempt, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Complete; - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Complete; + + LoginAttemptStore::update_if_state( + &*self.storage, + update, + LoginAttemptState::RemoteAuthenticated, + ) + .await } pub async fn fail_login_attempt( &self, attempt: LoginAttempt, + expected_state: LoginAttemptState, error: Option<&str>, provider_error: Option<&str>, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Failed; - attempt.error = error.map(|s| s.to_string()); - attempt.provider_error = provider_error.map(|s| s.to_string()); - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Failed; + update.error = error.map(|s| s.to_string()); + update.provider_error = provider_error.map(|s| s.to_string()); + + LoginAttemptStore::update_if_state(&*self.storage, update, expected_state).await } } diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index b73abdb3..4e5fb41c 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -764,6 +764,7 @@ pub struct VContextBuilder { keys: Option>, #[cfg(feature = "sagas")] saga: Option<(TypedUuid, Option)>, + additional_builtin_permissions: Vec, } impl Default for VContextBuilder @@ -790,6 +791,7 @@ where keys: None, #[cfg(feature = "sagas")] saga: None, + additional_builtin_permissions: Vec::new(), } } @@ -838,6 +840,11 @@ where self } + pub fn with_additional_builtin_permissions(mut self, permissions: Vec) -> Self { + self.additional_builtin_permissions = permissions; + self + } + pub async fn build(self) -> Result, VContextBuilderError> { if self.storage.is_some() && self.storage_url.is_some() { return Err(VContextBuilderError::ConfigConflict( @@ -900,7 +907,14 @@ where .into_iter() .filter_map(|key| key.ok()) .collect::>(); - let auth_ctx = AuthContext::new(jwt, jwks, signers, verifiers).map_err(|err| { + let auth_ctx = AuthContext::new( + jwt, + jwks, + signers, + verifiers, + self.additional_builtin_permissions, + ) + .map_err(|err| { tracing::error!(?err, "Auth context construction failed"); VContextError::InternalAuthContext })?; @@ -1245,8 +1259,13 @@ pub(crate) mod test_mocks { }; use crate::{ - config::JwtConfig, - endpoints::login::oauth::{google::GoogleOAuthProvider, OAuthProviderName}, + config::{ + JwtConfig, ResolvedOAuthConfig, ResolvedOAuthWebConfig, ResolvedOAuthWebProxyConfig, + }, + endpoints::login::oauth::{ + remote::google::GoogleOAuthProvider, remote::zendesk::ZendeskOAuthProvider, + OAuthProviderName, + }, mapper::DefaultMappingEngine, permissions::VPermission, util::tests::{mock_key, MockKey}, @@ -1259,7 +1278,7 @@ pub(crate) mod test_mocks { pub async fn mock_context(storage: Arc) -> VContext { let MockKey { signer, verifier } = mock_key("test"); let mut ctx = VContextBuilder::::new() - .with_public_url("".to_string()) + .with_public_url("https://test_public_url".to_string()) .with_storage(storage) .with_jwt_expiration(JwtConfig::default().default_expiration) .with_keys(vec![signer, verifier]) @@ -1278,10 +1297,38 @@ pub(crate) mod test_mocks { OAuthProviderName::Google, Box::new(move || { Box::new(GoogleOAuthProvider::new( - "google_device_client_id".to_string(), - "google_device_client_secret".to_string().into(), - "google_web_client_id".to_string(), - "google_web_client_secret".to_string().into(), + ResolvedOAuthConfig { + device: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "google_web_client_id".to_string(), + remote_client_secret: "google_web_client_secret".to_string().into(), + }), + proxy_web: None, + }, + "https://test_public_url".to_string(), + None, + )) + }), + ); + + ctx.auth.insert_oauth_provider( + OAuthProviderName::Zendesk, + Box::new(move || { + Box::new(ZendeskOAuthProvider::new( + ResolvedOAuthConfig { + device: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "zendesk_web_client_id".to_string(), + remote_client_secret: "zendesk_web_client_secret".to_string().into(), + }), + proxy_web: Some(ResolvedOAuthWebProxyConfig { + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test_public_url/pkce-callback".to_string(), + proxy_port: 1234, + }), + }, + "https://test_public_url".to_string(), + "subdomain".to_string(), None, )) }), @@ -1600,6 +1647,18 @@ pub(crate) mod test_mocks { .upsert(attempt) .await } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: v_model::LoginAttemptState, + ) -> Result { + self.login_attempt_store + .as_ref() + .unwrap() + .update_if_state(attempt, expected_state) + .await + } } #[async_trait] diff --git a/v-api/src/context/oauth.rs b/v-api/src/context/oauth.rs index 41835b60..1a9ee0c4 100644 --- a/v-api/src/context/oauth.rs +++ b/v-api/src/context/oauth.rs @@ -40,15 +40,10 @@ where pub async fn create_oauth_client( &self, caller: &Caller, + id: TypedUuid, ) -> ResourceResult { if caller.can(&VPermission::CreateOAuthClient.into()) { - Ok(OAuthClientStore::upsert( - &*self.storage, - NewOAuthClient { - id: TypedUuid::new_v4(), - }, - ) - .await?) + Ok(OAuthClientStore::upsert(&*self.storage, NewOAuthClient { id }).await?) } else { resource_restricted() } diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index f2dd46e4..f939ee38 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -69,15 +69,16 @@ mod macros { DeleteOAuthClientSecretPath, GetOAuthClientPath, InitialOAuthClientSecretResponse, }, - code::{ + flow::code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, - OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, + get_public_pkce_provider_op, + OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, OAuthAuthzCodeExchangeQuery, OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, }, - device_token::{ + flow::device_token::{ exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest, }, - OAuthProviderInfo, OAuthProviderNameParam, + OAuthProviderNameParam, OAuthProviderDeviceInfo, OAuthProviderAuthorizationCodePkceInfo } }, mappers::{ @@ -259,7 +260,7 @@ mod macros { // LOGIN ENDPOINTS - // AUTHZ CODE + // AUTHORIZATION CODE FLOW /// Generate the remote provider login url and redirect the user #[endpoint { @@ -295,15 +296,30 @@ mod macros { }] pub async fn authz_code_exchange( rqctx: RequestContext<$context_type>, + query: Query, path: Path, body: TypedBody, ) -> Result, HttpError> { - authz_code_exchange_op(&rqctx, path, body).await + authz_code_exchange_op(&rqctx, query, path, body).await } - // DEVICE CODE + // AUTHORIZATION CODE PKCE ONLY FLOW - /// Retrieve the metadata about an OAuth provider + /// Retrieve the metadata about an OAuth provider for public PKCE authorization code flow + #[endpoint { + method = GET, + path = "/login/oauth/{provider}/public-pkce" + }] + pub async fn get_web_pkce_provider( + rqctx: RequestContext<$context_type>, + path: Path, + ) -> Result, HttpError> { + get_public_pkce_provider_op(&rqctx, path).await + } + + // DEVICE CODE FLOW + + /// Retrieve the metadata about an OAuth provider for limited input flow #[endpoint { method = GET, path = "/login/oauth/{provider}/device" @@ -311,7 +327,7 @@ mod macros { pub async fn get_device_provider( rqctx: RequestContext<$context_type>, path: Path, - ) -> Result, HttpError> { + ) -> Result, HttpError> { get_device_provider_op(&rqctx, path).await } @@ -756,7 +772,9 @@ mod macros { $api.register(authz_code_exchange) .expect("Failed to register endpoint"); - // OAuth Device Login + // OAuth Login + $api.register(get_web_pkce_provider) + .expect("Failed to register endpoint"); $api.register(get_device_provider) .expect("Failed to register endpoint"); $api.register(exchange_device_token) diff --git a/v-api/src/endpoints/login/magic_link/client.rs b/v-api/src/endpoints/login/magic_link/client.rs index 58a70380..6013eb38 100644 --- a/v-api/src/endpoints/login/magic_link/client.rs +++ b/v-api/src/endpoints/login/magic_link/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ permissions::{Caller, PermissionStorage}, MagicLink, MagicLinkId, MagicLinkRedirectUri, MagicLinkRedirectUriId, MagicLinkSecret, @@ -19,7 +20,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, VContext, }; @@ -192,6 +193,18 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); + } + Ok(HttpResponseOk( ctx.magic_link .add_magic_link_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/magic_link/mod.rs b/v-api/src/endpoints/login/magic_link/mod.rs index 2108a5b6..e33d6a9e 100644 --- a/v-api/src/endpoints/login/magic_link/mod.rs +++ b/v-api/src/endpoints/login/magic_link/mod.rs @@ -218,6 +218,7 @@ where external_id: ExternalUserId::MagicLink(body.recipient.clone()), verified_emails: vec![body.recipient], display_name: None, + idp_token: None, }, ) .await?; @@ -312,8 +313,9 @@ impl CheckMagicLinkClient for MagicLink { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + super::is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index f9c471e4..6c9ef239 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -10,6 +10,7 @@ use serde::{ Deserialize, Deserializer, Serialize, Serializer, }; use thiserror::Error; +use url::Url; use crate::{ permissions::VPermission, @@ -60,6 +61,7 @@ impl From for HttpError { pub enum ExternalUserId { GitHub(String), Google(String), + Zendesk(String), #[cfg(feature = "local-dev")] Local(String), MagicLink(String), @@ -70,6 +72,7 @@ impl ExternalUserId { match self { Self::GitHub(id) => id, Self::Google(id) => id, + Self::Zendesk(id) => id, #[cfg(feature = "local-dev")] Self::Local(id) => id, Self::MagicLink(id) => id, @@ -80,6 +83,7 @@ impl ExternalUserId { match self { Self::GitHub(_) => "github", Self::Google(_) => "google", + Self::Zendesk(_) => "zendesk", #[cfg(feature = "local-dev")] Self::Local(_) => "local", Self::MagicLink(_) => "magic-link", @@ -103,6 +107,7 @@ impl Serialize for ExternalUserId { match self { ExternalUserId::GitHub(id) => serializer.serialize_str(&format!("github-{}", id)), ExternalUserId::Google(id) => serializer.serialize_str(&format!("google-{}", id)), + ExternalUserId::Zendesk(id) => serializer.serialize_str(&format!("zendesk-{}", id)), #[cfg(feature = "local-dev")] ExternalUserId::Local(id) => serializer.serialize_str(&format!("local-{}", id)), ExternalUserId::MagicLink(id) => { @@ -142,6 +147,12 @@ impl<'de> Deserialize<'de> for ExternalUserId { } else { Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) } + } else if let Some(("", id)) = value.split_once("zendesk-") { + if !id.is_empty() { + Ok(ExternalUserId::Zendesk(id.to_string())) + } else { + Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) + } } else if let Some(("", id)) = value.split_once("local-") { #[cfg(feature = "local-dev")] { @@ -181,6 +192,7 @@ pub struct UserInfo { pub external_id: ExternalUserId, pub verified_emails: Vec, pub display_name: Option, + pub idp_token: Option, } #[derive(Debug, Error)] @@ -191,11 +203,48 @@ pub enum UserInfoError { Deserialize(#[from] serde_json::Error), #[error("Failed to create user info request {0}")] Http(#[from] http::Error), + #[error("User account is locked")] + Locked, #[error("User information is missing")] MissingUserInfoData(String), + #[error("User info endpoint returned HTTP {status} for {endpoint}")] + UnexpectedStatus { + endpoint: String, + status: http::StatusCode, + }, } #[async_trait] pub trait UserInfoProvider { async fn get_user_info(&self, token: &str) -> Result; } + +/// Structurally compare a candidate redirect URI against a list of registered redirect URIs. +/// Comparison is performed on scheme, host, port, and path. URIs that fail to parse or contain +/// fragments are rejected (per RFC 6749 §3.1.2). +pub fn is_redirect_uri_valid<'a>( + redirect_uri: &str, + registered_uris: impl Iterator, +) -> bool { + let candidate = match Url::parse(redirect_uri) { + Ok(url) => url, + Err(_) => return false, + }; + + // Reject redirect URIs that contain fragments (per RFC 6749 §3.1.2) + if candidate.fragment().is_some() { + return false; + } + + registered_uris + .into_iter() + .any(|registered| match Url::parse(registered) { + Ok(registered) => { + registered.scheme() == candidate.scheme() + && registered.host() == candidate.host() + && registered.port() == candidate.port() + && registered.path() == candidate.path() + } + Err(_) => false, + }) +} diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index c8fd8304..085d1966 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ permissions::{Caller, PermissionStorage}, OAuthClient, OAuthClientId, OAuthClientRedirectUri, OAuthClientSecret, OAuthRedirectUriId, @@ -19,7 +20,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, VContext, }; @@ -54,7 +55,10 @@ where T: VAppPermission + From + PermissionStorage, { // Create the new client - let client = ctx.oauth.create_oauth_client(&caller).await?; + let client = ctx + .oauth + .create_oauth_client(&caller, TypedUuid::new_v4()) + .await?; // Give the caller permission to perform actions on the client ctx.user @@ -188,6 +192,18 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); + } + Ok(HttpResponseOk( ctx.oauth .add_oauth_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs deleted file mode 100644 index 0c9d5c11..00000000 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ /dev/null @@ -1,279 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use chrono::{DateTime, Utc}; -use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{header, HeaderValue, Response, StatusCode}; -use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; -use schemars::JsonSchema; -use secrecy::ExposeSecret; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; -use tap::TapFallible; -use tracing::instrument; -use v_model::permissions::PermissionStorage; - -use super::{ - ClientType, OAuthProvider, OAuthProviderInfo, OAuthProviderNameParam, UserInfoProvider, -}; -use crate::{ - context::ApiContext, endpoints::login::LoginError, error::ApiError, - permissions::VAppPermission, response::internal_error, util::response::bad_request, -}; - -#[instrument(skip(rqctx), err(Debug))] -pub async fn get_device_provider_op( - rqctx: &RequestContext>, - path: Path, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let path = path.into_inner(); - - tracing::trace!("Getting OAuth data for {}", path.provider); - - let provider = rqctx - .v_ctx() - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - Ok(HttpResponseOk(provider.provider_info( - rqctx.v_ctx().public_url(), - &ClientType::Device, - ))) -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct AccessTokenExchangeRequest { - pub device_code: String, - pub grant_type: String, - pub expires_at: Option>, -} - -#[derive(Serialize)] -pub struct AccessTokenExchange { - provider: ProviderTokenExchange, - expires_at: Option>, -} - -#[derive(Serialize)] -pub struct ProviderTokenExchange { - client_id: String, - device_code: String, - grant_type: String, - client_secret: String, -} - -impl AccessTokenExchange { - pub fn new( - req: AccessTokenExchangeRequest, - provider: &(dyn OAuthProvider + Send + Sync), - ) -> Option { - provider - .client_secret(&ClientType::Device) - .map(|client_secret| Self { - provider: ProviderTokenExchange { - client_id: provider.client_id(&ClientType::Device).to_string(), - device_code: req.device_code, - grant_type: req.grant_type, - client_secret: client_secret.expose_secret().to_string(), - }, - expires_at: req.expires_at, - }) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenResponse { - pub access_token: String, - pub token_type: String, - pub expires_in: Option, - pub refresh_token: Option, - pub scopes: Option>, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenError { - error: String, - error_description: Option, - error_uri: Option, -} - -// Complete a device exchange request against the specified provider. This effectively proxies the -// requests that would go to the provider, captures the returned access tokens, and registers a -// new internal user as needed. The user is then returned an token that is valid for interacting -// with the API -#[instrument(skip(rqctx, body), err(Debug))] -pub async fn exchange_device_token_op( - rqctx: &RequestContext>, - path: Path, - body: TypedBody, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); - - let exchange_request = body.into_inner(); - - if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { - let token_exchange_endpoint = provider.token_exchange_endpoint(); - let client = reqwest::Client::new(); - - let response = client - .request(Method::POST, token_exchange_endpoint) - .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) - .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .body( - // We know that this is safe to unwrap as we just deserialized it via the body Extractor - serde_urlencoded::to_string(&exchange.provider).unwrap(), - ) - .send() - .await - .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) - .map_err(internal_error)?; - - // Take a part the response as we will need the individual parts later - let status = response.status(); - let headers = response.headers().clone(); - let bytes = response.bytes().await.map_err(internal_error)?; - - // We unfortunately can not trust our providers to follow specs and therefore need to do - // our own inspection of the response to determine what to do - if !status.is_success() { - // If the server returned a non-success status then we are going to trust the server and - // report their error back to the client - tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - Ok(client_response) - } else { - // The server gave us back a non-error response but it still may not be a success. - // GitHub for instance does not use a status code for indicating the success or failure - // of a call. So instead we try to deserialize the body into an access token, with the - // understanding that it may fail and we will need to try and treat the response as - // an error instead. - - let parsed: Result< - StandardTokenResponse, - serde_json::Error, - > = serde_json::from_slice(&bytes); - - match parsed { - Ok(parsed) => { - let info = provider - .get_user_info(parsed.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| { - tracing::error!(?err, "Failed to look up user information") - })?; - - tracing::debug!("Verified and validated OAuth user"); - - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); - - let claims = - ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); - let token = ctx - .user - .register_access_token( - &ctx.builtin_registration_user(), - ctx.jwt_signer(), - &api_user_info.user.id, - &claims, - ) - .await?; - - tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/json") - .body( - serde_json::to_string(&ProxyTokenResponse { - access_token: token.signed_token, - token_type: "Bearer".to_string(), - expires_in: Some(claims.exp - Utc::now().timestamp()), - refresh_token: None, - scopes: None, - }) - .unwrap() - .into(), - )?) - } - Err(_) => { - // Do not log the error here as we want to ensure we do not leak token information - tracing::debug!( - "Failed to parse a success response from the remote token endpoint" - ); - - // Try to deserialize the body again, but this time as an error - let mut error_response = match serde_json::from_slice::(&bytes) - { - Ok(error) => { - // We found an error in the message body. This is not ideal, but we at - // least can understand what the server was trying to tell us - tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - client_response - } - Err(_) => { - // We still do not know what the remote server is doing... and need to - // cancel the request ourselves - tracing::warn!( - "Remote OAuth provide returned a response that we do not undestand" - ); - - Response::new( - serde_json::to_vec(&ProxyTokenError { - error: "access_denied".to_string(), - error_description: Some(format!( - "{} returned a malformed response", - path.provider - )), - error_uri: None, - }) - .unwrap() - .into(), - ) - } - }; - - *error_response.status_mut() = StatusCode::BAD_REQUEST; - error_response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - Ok(error_response) - } - } - } - } else { - tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); - - Err(bad_request("Invalid provider")) - } -} diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs similarity index 70% rename from v-api/src/endpoints/login/oauth/code.rs rename to v-api/src/endpoints/login/oauth/flow/code.rs index 39d3a4a8..d5abcb73 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -12,7 +12,7 @@ use dropshot::{ }; use dropshot_authorization_header::basic::BasicAuth; use http::{header::SET_COOKIE, HeaderValue}; -use newtype_uuid::TypedUuid; +use newtype_uuid::{GenericUuid, TypedUuid}; use oauth2::{ AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, }; @@ -24,22 +24,25 @@ use sha2::{Digest, Sha256}; use std::{fmt::Debug, ops::Add}; use tap::TapFallible; use tracing::instrument; +use uuid::Uuid; use v_model::{ permissions::{AsScope, PermissionStorage}, schema_ext::LoginAttemptState, LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, }; -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig}; +use super::super::{OAuthProvider, OAuthProviderNameParam}; +use crate::endpoints::login::UserInfoProvider; use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, endpoints::login::{ - oauth::{CheckOAuthClient, ClientType}, + oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, LoginError, UserInfo, }, error::ApiError, permissions::{VAppPermission, VPermission}, + response::bad_request, secrets::OpenApiSecretString, util::{ request::RequestCookies, @@ -48,8 +51,26 @@ use crate::{ }; static LOGIN_ATTEMPT_COOKIE: &str = "__v_login"; +static LOGIN_ATTEMPT_COOKIE_PATH: &str = "/login/oauth/"; static DEFAULT_SCOPE: &str = "user:info:r"; +/// Build the login attempt cookie with consistent attributes. +/// The `Path` is scoped to the OAuth login endpoints so the cookie is not +/// sent to unrelated paths on the same domain. +fn build_login_attempt_cookie<'a>( + value: &'a str, + public_url: &str, + max_age_secs: i64, +) -> Cookie<'a> { + let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, value.to_string()); + cookie.set_path(LOGIN_ATTEMPT_COOKIE_PATH); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::Lax); + cookie.set_secure(public_url.starts_with("https")); + cookie.set_max_age(cookie::time::Duration::seconds(max_age_secs)); + cookie +} + #[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] struct OAuthError { error: OAuthErrorCode, @@ -62,7 +83,7 @@ struct OAuthError { } #[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -#[serde(untagged)] +#[serde(rename_all = "snake_case")] enum OAuthErrorCode { AccessDenied, InvalidClient, @@ -96,6 +117,10 @@ pub struct OAuthAuthzCodeQuery { pub response_type: String, pub state: String, pub scope: Option, + /// PKCE code challenge (RFC 7636). Required for all authorization code flows. + pub code_challenge: String, + /// PKCE code challenge method. Must be "S256". + pub code_challenge_method: String, } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -152,6 +177,32 @@ where } } +#[instrument(skip(rqctx), err(Debug))] +pub async fn get_public_pkce_provider_op( + rqctx: &RequestContext>, + path: Path, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let path = path.into_inner(); + + tracing::trace!("Getting OAuth data for {}", path.provider); + + let provider = rqctx + .v_ctx() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider + .authz_code_pkce_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support web pkce clients"))?, + )) +} + #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_redirect_op( rqctx: &RequestContext>, @@ -169,6 +220,40 @@ where tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); + // Validate the client's PKCE challenge method. Only S256 is supported. + if query.code_challenge_method != "S256" { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some( + "Unsupported code_challenge_method. Only S256 is supported.".to_string(), + ), + error_uri: None, + state: None, + } + .into()); + } + + // Validate the PKCE code challenge. For S256, this must be a base64url-no-pad + // encoding of a SHA256 hash, which is always exactly 43 characters of [A-Za-z0-9_-] + // (RFC 7636 §4.2). + if query.code_challenge.len() != 43 + || !query + .code_challenge + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') + { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some( + "Invalid code_challenge. Must be a base64url-encoded SHA256 hash (43 characters)." + .to_string(), + ), + error_uri: None, + state: None, + } + .into()); + } + // Find the configured provider for the requested remote backend. We should always have a valid // provider value, so if this fails then a 500 is returned let provider = ctx @@ -180,9 +265,16 @@ where // Check that the passed in scopes are valid. The scopes are not currently restricted by client let scope = query.scope.unwrap_or_else(|| DEFAULT_SCOPE.to_string()); - let scope_error = VPermission::from_scope_arg(&scope) - .err() - .map(|_| "invalid_scope".to_string()); + if let Err(err) = VPermission::from_scope_arg(&scope) { + tracing::warn!(?err, ?scope, "Client submitted an invalid scope"); + return Err(OAuthError { + error: OAuthErrorCode::InvalidScope, + error_description: Some(format!("Invalid scope: {}", scope)), + error_uri: None, + state: None, + } + .into()); + } // Construct a new login attempt with the minimum required values let mut attempt = NewLoginAttempt::new( @@ -200,16 +292,19 @@ where // TODO: Make this configurable attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); - // Assign any scope errors that arose - attempt.error = scope_error; - // Add in the user defined state and redirect uri. State is an arbitrary value and may be // malicious. It must be url-encoded before being presented back to the client. Therefore we // process once before storing so all downstream consumers see the encoded value. attempt.state = Some(percent_encode(query.state.as_bytes(), NON_ALPHANUMERIC).to_string()); - // If the remote provider supports pkce, set up a challenge - let pkce_challenge = if provider.supports_pkce() { + // Always store the client's PKCE challenge so we can verify it during the token exchange. + // This is the client-to-v-api PKCE leg and is mandatory for all flows. + attempt.pkce_challenge = Some(query.code_challenge); + attempt.pkce_challenge_method = Some(query.code_challenge_method); + + // If the remote provider supports PKCE, also set up a challenge for the v-api-to-remote leg. + // This is independent of the client-to-v-api PKCE above. + let remote_pkce_challenge = if provider.supports_pkce() { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); attempt.provider_pkce_verifier = Some(pkce_verifier.secret().to_string()); Some(pkce_challenge) @@ -226,7 +321,12 @@ where tracing::info!(?attempt.id, "Created login attempt"); - oauth_redirect_response(ctx.public_url(), &*provider, &attempt, pkce_challenge) + oauth_redirect_response( + ctx.public_url(), + &*provider, + &attempt, + remote_pkce_challenge, + ) } fn oauth_redirect_response( @@ -239,20 +339,12 @@ fn oauth_redirect_response( // TODO: This behavior should be changed so that clients are precomputed. We do not need to be // constructing a new client on every request. That said, we need to ensure the client does not // maintain state between requests - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; // Create an attempt cookie header for storing the login attempt. This also acts as our csrf // check - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, attempt.id.to_string()); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(public_url.starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(600)); - + let attempt_id_str = attempt.id.to_string(); + let cookie = build_login_attempt_cookie(&attempt_id_str, public_url, 600); let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; // Generate the url to the remote provider that the user will be redirected to @@ -260,8 +352,8 @@ fn oauth_redirect_response( .authorize_url(|| CsrfToken::new(attempt.id.to_string())) .add_scopes( provider - .scopes() - .into_iter() + .default_scopes() + .iter() .map(|s| Scope::new(s.to_string())) .collect::>(), ); @@ -310,7 +402,7 @@ fn verify_csrf( .value() .parse() .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); + tracing::warn!(?err, "Failed to parse state cookie"); unauthorized() })?; @@ -358,11 +450,7 @@ where let attempt_id = verify_csrf(&rqctx.request, &query)?; // Clear the login attempt cookie - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, ""); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(ctx.public_url().starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(0)); + let cookie = build_login_attempt_cookie("", ctx.public_url(), 0); let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; let mut redirect = http_response_temporary_redirect( @@ -430,14 +518,27 @@ where // TODO: Specialize the returned error ctx.login - .fail_login_attempt(attempt, Some(error_message), error.as_deref()) + .fail_login_attempt( + attempt, + LoginAttemptState::New, + Some(error_message), + error.as_deref(), + ) .await .map_err(to_internal_error)? } }; // Redirect back to the original authenticator - Ok(attempt.callback_url()) + attempt.callback_url().map_err(|err| { + tracing::error!(?err, redirect_uri = ?attempt.redirect_uri, "Login attempt contains an invalid redirect URI"); + to_internal_error(err) + }) +} + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct OAuthAuthzCodeExchangeQuery { + pub request_idp_token: bool, } #[derive(Debug, Deserialize, JsonSchema)] @@ -447,7 +548,8 @@ pub struct OAuthAuthzCodeExchangeBody { pub redirect_uri: String, pub grant_type: String, pub code: String, - pub pkce_verifier: Option, + /// PKCE code verifier (RFC 7636). Required for all authorization code exchanges. + pub pkce_verifier: String, } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -455,11 +557,18 @@ pub struct OAuthAuthzCodeExchangeResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, + pub idp_token: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeIdpToken { + pub token: String, } #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_exchange_op( rqctx: &RequestContext>, + query: Query, path: Path, body: TypedBody, ) -> Result, HttpError> @@ -467,49 +576,81 @@ where T: VAppPermission + PermissionStorage, { let ctx = rqctx.v_ctx(); + let query = query.into_inner(); let path = path.into_inner(); let body = body.into_inner(); - let (client_id, client_secret) = - if let (Some(client_id), Some(client_secret)) = (body.client_id, body.client_secret) { - Ok::<_, HttpError>((client_id, client_secret)) - } else { - // Attempt to extract basic authorization credentials from the request if they were not - // present in the request body - let auth = ::from_request(rqctx) - .await - .tap_err(|err| { - tracing::warn!(?err, "Failed to extract basic authentication values"); - }); - let (client_id, client_secret) = match auth { - Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(( - auth.username().unwrap().to_string(), - auth.password().unwrap().to_string(), - )), - _ => Err(internal_error( - "Missing client id and client secret from authz code exchange", - )), - }?; - - Ok(( - client_id.parse().map_err(to_internal_error)?, - OpenApiSecretString(client_secret.into()), - )) - }?; - let provider = ctx .get_oauth_provider(&path.provider) .await .map_err(ApiError::OAuth)?; + // Extract basic authorization credentials from the request if they were provided. + let auth = ::from_request(rqctx) + .await + .tap_err(|err| { + tracing::warn!(?err, "Failed to extract basic authentication values"); + }); + let basic_credentials = match auth { + Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(Some(( + TypedUuid::from_untyped_uuid( + Uuid::parse_str(auth.username().unwrap()) + .map_err(|_| bad_request("Malformed client ID presented to code exchange"))?, + ), + auth.password().unwrap().to_string(), + ))), + Ok(auth) if auth.username().is_none() && auth.password().is_none() => { + tracing::info!("Credentials for code exchange not defined via basic auth"); + Ok(None) + } + Ok(_) => Err(bad_request( + "Malformed credentials presented to code exchange", + )), + Err(err) => { + tracing::info!(?err, "Failed to extract basic authentication credentials"); + Ok(None) + } + }?; + + // Extract credentials from the request body if they were provided. + let body_credentials = (body.client_id, body.client_secret); + + // Now validate if the credentials provided by the client support one of our expected schemes. + // We of course deny underspecifying credentials, but we also want to disallow over specifying + // them. For example, if the client provides both basic auth and a client id/secret in the + // request body, we should reject the request. + tracing::debug!( + ?basic_credentials, + ?body_credentials, + "Extracted credentials from request" + ); + let (client_id, client_secret) = match (basic_credentials, body_credentials) { + (Some(_), (Some(_), _)) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some(_), (_, Some(_))) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some((client_id, client_secret)), (None, None)) => Ok(( + client_id, + Some(OpenApiSecretString(SecretString::from(client_secret))), + )), + (None, (Some(client_id), Some(client_secret))) => Ok((client_id, Some(client_secret))), + (None, (Some(client_id), _)) if provider.authz_code_pkce_flow_info().is_some() => { + Ok((client_id, None)) + } + _ => Err(bad_request("Missing client credentials")), + }?; + tracing::debug!("Attempting code exchange"); // Verify the submitted client credentials authorize_code_exchange( ctx, + &*provider, &body.grant_type, client_id, - &client_secret.0, + client_secret.map(|s| s.0).as_ref(), &body.redirect_uri, ) .await?; @@ -532,37 +673,54 @@ where // Verify that the login attempt is valid and matches the submitted client credentials verify_login_attempt( &attempt, + &provider.name().to_string(), client_id, &body.redirect_uri, - body.pkce_verifier.as_deref(), + &body.pkce_verifier, )?; tracing::debug!("Verified login attempt"); - // Now that the attempt has been confirmed, use it to fetch user information form the remote - // provider - let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?; - - tracing::debug!("Retrieved user information from remote provider"); - - // During fetch_user_info we revoke any downstream codes if possible, therefore At this point we - // consider the login attempt to be consumed and can no longer be used. We state transition to - // complete, even though we may fail further along in the handler. If a failure occurs then the - // user will need to re-authenticate. + // Atomically claim this login attempt before doing any remote work. This transitions + // the attempt from RemoteAuthenticated -> Complete in a single conditional UPDATE, + // ensuring that a concurrent request using the same authorization code will fail. + // Per RFC 6749 §4.1.2, authorization codes MUST be single-use. + let attempt_id = attempt.id; attempt = ctx .login - .complete_login_attempt(attempt) + .claim_login_attempt(attempt) .await .map_err(|err| { - tracing::error!(?err, "Failed to complete login attempt"); + tracing::warn!( + ?err, + ?attempt_id, + "Failed to claim login attempt (may have been consumed by a concurrent request)" + ); OAuthError { - error: OAuthErrorCode::ServerError, - error_description: Some("An unexpected error occurred".to_string()), + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Authorization code has already been used".to_string()), error_uri: None, state: None, } })?; + tracing::debug!("Claimed login attempt"); + + // Now that the attempt has been claimed, use it to fetch user information from the + // remote provider. If this fails, the attempt is already consumed and the user must + // re-authenticate. + let info = fetch_user_info( + ctx.public_url(), + &ctx.web_client(), + &*provider, + &attempt, + !query.request_idp_token, + ) + .await?; + let idp_token = info.idp_token.clone(); + + tracing::debug!("Retrieved user information from remote provider"); + // Register this user as an API user if needed let (api_user_info, api_user_provider) = ctx .register_api_user(&ctx.builtin_registration_user(), info) @@ -589,14 +747,16 @@ where token_type: "Bearer".to_string(), access_token: token.signed_token, expires_in: token.expires_in, + idp_token, })) } async fn authorize_code_exchange( ctx: &VContext, + provider: &dyn OAuthProvider, grant_type: &str, client_id: TypedUuid, - client_secret: &SecretString, + client_secret: Option<&SecretString>, redirect_uri: &str, ) -> Result<(), OAuthError> where @@ -616,40 +776,62 @@ where tracing::debug!(grant_type, "Verified grant type"); - let client_secret = RawKey::try_from(client_secret).map_err(|err| { - tracing::warn!(?err, "Failed to parse OAuth client secret"); + // If we were provided a client secret, then it must be verified. If a client secret was not + // provided, then we can skip this step as long as the provider supports pkce_only + // authentication. + if let Some(client_secret) = client_secret { + let client_secret = RawKey::try_from(client_secret).map_err(|err| { + tracing::warn!(?err, "Failed to parse OAuth client secret"); - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Malformed client secret".to_string()), - error_uri: None, - state: None, - } - })?; + OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Malformed client secret".to_string()), + error_uri: None, + state: None, + } + })?; - tracing::debug!("Constructed client secret"); + tracing::debug!("Constructed client secret"); + + if !client.is_secret_valid(&client_secret, ctx) { + Err(OAuthError { + error: OAuthErrorCode::InvalidClient, + error_description: Some("Invalid client secret".to_string()), + error_uri: None, + state: None, + }) + } else { + tracing::debug!("Verified client secret validity"); - if !client.is_secret_valid(&client_secret, ctx) { + Ok(()) + } + } else if provider.authz_code_pkce_flow_info().is_some() { + Ok(()) + } else { Err(OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Invalid client secret".to_string()), + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Client secret required".to_string()), error_uri: None, state: None, }) - } else { - tracing::debug!("Verified client secret validity"); - - Ok(()) } } fn verify_login_attempt( attempt: &LoginAttempt, + provider: &str, client_id: TypedUuid, redirect_uri: &str, - pkce_verifier: Option<&str>, + pkce_verifier: &str, ) -> Result<(), OAuthError> { - if attempt.client_id != client_id { + if attempt.provider != provider { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.client_id != client_id { Err(OAuthError { error: OAuthErrorCode::InvalidGrant, error_description: Some("Invalid client id".to_string()), @@ -678,16 +860,10 @@ fn verify_login_attempt( state: None, }) } else { - match (attempt.pkce_challenge.as_deref(), pkce_verifier) { - (Some(_), None) => Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }), - (Some(challenge), Some(verifier)) => { + match attempt.pkce_challenge.as_deref() { + Some(challenge) => { let mut hasher = Sha256::new(); - hasher.update(verifier); + hasher.update(pkce_verifier); let hash = hasher.finalize(); let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash); @@ -702,7 +878,14 @@ fn verify_login_attempt( }) } } - (None, _) => Ok(()), + // PKCE is mandatory for all authorization code flows. A missing challenge + // means the login attempt was not properly initialized. + None => Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), + error_uri: None, + state: None, + }), } } } @@ -713,13 +896,13 @@ async fn fetch_user_info( client_type: &ClientType, provider: &dyn OAuthProvider, attempt: &LoginAttempt, + revoke_idp_token: bool, ) -> Result { + let provider_info = provider + .authz_code_flow_info() + .ok_or_else(|| internal_error("Authorization code flow not supported"))?; // Exchange the stored authorization code with the remote provider for a remote access token - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; let mut request = client.exchange_code(AuthorizationCode::new( attempt @@ -754,7 +937,7 @@ async fn fetch_user_info( // Now that we are done with fetching user information from the remote API, we can revoke it if // the provider supports it - if provider.token_revocation_endpoint().is_some() { + if revoke_idp_token && provider_info.remote.revocation_endpoint.is_some() { client .revoke_token(response.access_token().into()) .map_err(internal_error)? @@ -799,7 +982,7 @@ mod tests { VContext, }, endpoints::login::oauth::{ - code::{ + flow::code::{ authz_code_callback_op_inner, verify_csrf, verify_login_attempt, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, LOGIN_ATTEMPT_COOKIE, }, @@ -819,7 +1002,7 @@ mod tests { .unwrap(); let secret_signature = key.signature().to_string(); let client_secret = key.key(); - let redirect_uri = "callback-destination"; + let redirect_uri = "https://example.com/callback"; ( ctx, @@ -889,8 +1072,7 @@ mod tests { #[tokio::test] async fn test_remote_provider_redirect_url() { let storage = MockStorage::new(); - let mut ctx = mock_context(Arc::new(storage)).await; - ctx.with_public_url("https://api.oxeng.dev"); + let ctx = mock_context(Arc::new(storage)).await; let (challenge, _) = PkceCodeChallenge::new_random_sha256(); let attempt = LoginAttempt { @@ -927,7 +1109,7 @@ mod tests { .unwrap(); let headers = response.headers(); - let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); + let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); assert_eq!( expected_location, @@ -935,7 +1117,7 @@ mod tests { ); assert_eq!( format!( - "{}; HttpOnly; SameSite=Lax; Secure; Max-Age=600", + "{}; HttpOnly; SameSite=Lax; Secure; Path=/login/oauth/; Max-Age=600", attempt.id ) .as_str(), @@ -1100,9 +1282,12 @@ mod tests { .returning(move |_| Ok(Some(original_attempt.clone()))); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; @@ -1124,7 +1309,7 @@ mod tests { .unwrap(); assert_eq!( - format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",), + format!("https://test.oxeng.dev/callback?state=ox_state&error=server_error",), location ); } @@ -1160,9 +1345,12 @@ mod tests { .returning(move |_| Ok(Some(original_attempt.clone()))); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; @@ -1184,7 +1372,7 @@ mod tests { .unwrap(); assert_eq!( - format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",), + format!("https://test.oxeng.dev/callback?state=ox_state&error=access_denied",), location ); } @@ -1222,9 +1410,12 @@ mod tests { let extracted_code = Arc::new(Mutex::new(None)); let extractor = extracted_code.clone(); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::RemoteAuthenticated) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; @@ -1244,7 +1435,7 @@ mod tests { let lock = extracted_code.lock(); assert_eq!( format!( - "https://test.oxeng.dev/callback?code={}&state=ox_state", + "https://test.oxeng.dev/callback?state=ox_state&code={}", lock.unwrap().as_ref().unwrap() ), location @@ -1275,15 +1466,20 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); // 1. Verify exchange fails when passing an incorrect client id assert_eq!( Some("Unknown client id".to_string()), authorize_code_exchange( &ctx, + &*provider, "authorization_code", wrong_client_id, - &client_secret, + Some(&client_secret), &redirect_uri, ) .await @@ -1296,9 +1492,10 @@ mod tests { Some("Invalid redirect uri".to_string()), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), "wrong-callback-destination", ) .await @@ -1306,14 +1503,75 @@ mod tests { .error_description ); - // 3. Verify a successful exchange + // 3. Verify a successful exchange with a client secret assert_eq!( (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), + &redirect_uri, + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_exchange_requires_secret_except_for_pkce_only() { + let (mut ctx, client, _) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + let pkce_only_provider = ctx + .get_oauth_provider(&OAuthProviderName::Zendesk) + .await + .unwrap(); + + // 1. Verify exchange fails when not passing a client secret for a client that does not + // support pkce_only + assert_eq!( + Some("Client secret required".to_string()), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + None, + &redirect_uri, + ) + .await + .unwrap_err() + .error_description + ); + + // 2. Verify exchange passes when omitting the client secret for a client that does + // support pkce_only + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*pkce_only_provider, + "authorization_code", + client_id, + None, &redirect_uri, ) .await @@ -1337,14 +1595,19 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); assert_eq!( OAuthErrorCode::UnsupportedGrantType, authorize_code_exchange( &ctx, + &*provider, "not_authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await @@ -1356,9 +1619,10 @@ mod tests { (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await @@ -1382,6 +1646,10 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) .sign(&*ctx.signer()) @@ -1394,9 +1662,10 @@ mod tests { OAuthErrorCode::InvalidRequest, authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &"too-short".to_string().into(), + Some(&"too-short".to_string().into()), &redirect_uri ) .await @@ -1408,9 +1677,10 @@ mod tests { OAuthErrorCode::InvalidClient, authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &invalid_secret.into(), + Some(&invalid_secret.into()), &redirect_uri ) .await @@ -1422,9 +1692,10 @@ mod tests { (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await @@ -1469,9 +1740,10 @@ mod tests { }, verify_login_attempt( &bad_client_id, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1490,9 +1762,10 @@ mod tests { }, verify_login_attempt( &bad_redirect_uri, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1511,9 +1784,10 @@ mod tests { }, verify_login_attempt( &unconfirmed_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1532,9 +1806,10 @@ mod tests { }, verify_login_attempt( &already_used_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1553,9 +1828,10 @@ mod tests { }, verify_login_attempt( &failed_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1574,27 +1850,35 @@ mod tests { }, verify_login_attempt( &expired, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); - let missing_pkce = LoginAttempt { ..attempt.clone() }; + // Verify that a login attempt with no stored PKCE challenge is rejected. + // PKCE is mandatory, so a missing challenge means the attempt is invalid. + let missing_challenge = LoginAttempt { + pkce_challenge: None, + pkce_challenge_method: None, + ..attempt.clone() + }; assert_eq!( OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), error_uri: None, state: None, }, verify_login_attempt( - &missing_pkce, + &missing_challenge, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - None, + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1613,22 +1897,131 @@ mod tests { }, verify_login_attempt( &invalid_pkce, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + assert_eq!( + (), + verify_login_attempt( + &attempt, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), + ) + .unwrap() + ); + } + + #[tokio::test] + async fn test_provider_mismatch_is_rejected() { + let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + + // Login attempt was created via Google + let attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some(challenge.as_str().to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + // Exchanging against a different provider must fail + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &attempt, + "github", + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), ) .unwrap_err() ); + // Exchanging against the correct provider must succeed assert_eq!( (), verify_login_attempt( &attempt, + "google", attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap() ); } + + #[test] + fn test_login_attempt_cookie_has_path() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.path(), Some(super::LOGIN_ATTEMPT_COOKIE_PATH)); + } + + #[test] + fn test_login_attempt_cookie_is_http_only() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.http_only(), Some(true)); + } + + #[test] + fn test_login_attempt_cookie_is_same_site_lax() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.same_site(), Some(cookie::SameSite::Lax)); + } + + #[test] + fn test_login_attempt_cookie_is_secure_for_https() { + let https_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + assert_eq!(https_cookie.secure(), Some(true)); + + let http_cookie = + super::build_login_attempt_cookie("test-attempt-id", "http://localhost", 600); + assert_eq!(http_cookie.secure(), Some(false)); + } + + #[test] + fn test_login_attempt_clear_cookie_has_same_path() { + // The clear cookie must use the same Path as the set cookie, + // otherwise browsers won't clear it. + let set_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + let clear_cookie = super::build_login_attempt_cookie("", "https://example.com", 0); + + assert_eq!(set_cookie.path(), clear_cookie.path()); + assert_eq!( + clear_cookie.max_age(), + Some(cookie::time::Duration::seconds(0)) + ); + } } diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs new file mode 100644 index 00000000..064ff2d5 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -0,0 +1,482 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use chrono::{DateTime, Utc}; +use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; +use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; +use hyper::body::Bytes; +use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; +use schemars::JsonSchema; +use secrecy::ExposeSecret; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use tap::TapFallible; +use tracing::instrument; +use v_model::permissions::PermissionStorage; + +use super::super::OAuthProviderNameParam; +use crate::endpoints::login::UserInfoProvider; +use crate::{ + context::ApiContext, + endpoints::login::{oauth::OAuthProviderDeviceInfo, LoginError}, + error::ApiError, + permissions::VAppPermission, + response::internal_error, + util::response::bad_request, +}; + +#[instrument(skip(rqctx), err(Debug))] +pub async fn get_device_provider_op( + rqctx: &RequestContext>, + path: Path, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let path = path.into_inner(); + + tracing::trace!("Getting OAuth data for {}", path.provider); + + let provider = rqctx + .v_ctx() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider + .device_code_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support device clients"))?, + )) +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct AccessTokenExchangeRequest { + pub device_code: String, + pub grant_type: String, + pub expires_at: Option>, +} + +#[derive(Serialize)] +pub struct AccessTokenExchange { + provider: ProviderTokenExchange, + expires_at: Option>, +} + +#[derive(Serialize)] +pub struct ProviderTokenExchange { + client_id: String, + device_code: String, + grant_type: String, + client_secret: String, +} + +impl AccessTokenExchange { + pub fn new(req: AccessTokenExchangeRequest, provider: &OAuthProviderDeviceInfo) -> Self { + Self { + provider: ProviderTokenExchange { + client_id: provider.remote_client_id.clone(), + device_code: req.device_code, + grant_type: req.grant_type, + client_secret: provider.remote_client_secret.0.expose_secret().to_string(), + }, + expires_at: req.expires_at, + } + } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct ProxyTokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: Option, + pub refresh_token: Option, + pub scopes: Option>, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct ProxyTokenError { + error: String, + error_description: Option, + error_uri: Option, +} + +// Complete a device exchange request against the specified provider. This effectively proxies the +// requests that would go to the provider, captures the returned access tokens, and registers a +// new internal user as needed. The user is then returned an token that is valid for interacting +// with the API +#[instrument(skip(rqctx, body), err(Debug))] +pub async fn exchange_device_token_op( + rqctx: &RequestContext>, + path: Path, + body: TypedBody, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let ctx = rqctx.v_ctx(); + let path = path.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + let device_info = provider.device_code_flow_info(); + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); + + if device_info.is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some(format!( + "{} does not support device code flow", + path.provider + )), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + + let device_info = device_info.unwrap(); + let exchange_request = body.into_inner(); + let exchange = AccessTokenExchange::new(exchange_request, device_info); + + let client = reqwest::Client::new(); + + let response = client + .request(Method::POST, &device_info.token_endpoint) + .header( + header::CONTENT_TYPE, + &device_info.token_endpoint_content_type, + ) + .header(header::ACCEPT, HeaderValue::from_static("application/json")) + .body( + // We know that this is safe to unwrap as we just deserialized it via the body Extractor + serde_urlencoded::to_string(&exchange.provider).unwrap(), + ) + .send() + .await + .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) + .map_err(internal_error)?; + + // Take a part the response as we will need the individual parts later + let status = response.status(); + let headers = response.headers().clone(); + let bytes = response.bytes().await.map_err(internal_error)?; + + // We unfortunately can not trust our providers to follow specs and therefore need to do + // our own inspection of the response to determine what to do + if !status.is_success() { + // If the server returned a non-success status then we are going to trust the server and + // report their error back to the client + tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); + + Ok(proxy_upstream_response(bytes, headers, status)) + } else { + // The server gave us back a non-error response but it still may not be a success. + // GitHub for instance does not use a status code for indicating the success or failure + // of a call. So instead we try to deserialize the body into an access token, with the + // understanding that it may fail and we will need to try and treat the response as + // an error instead. + + let parsed: Result< + StandardTokenResponse, + serde_json::Error, + > = serde_json::from_slice(&bytes); + + match parsed { + Ok(parsed) => { + let info = provider + .get_user_info(parsed.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::debug!("Verified and validated OAuth user"); + + let (api_user_info, api_user_provider) = ctx + .register_api_user(&ctx.builtin_registration_user(), info) + .await?; + + tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); + + let claims = + ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); + let token = ctx + .user + .register_access_token( + &ctx.builtin_registration_user(), + ctx.jwt_signer(), + &api_user_info.user.id, + &claims, + ) + .await?; + + tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_string(&ProxyTokenResponse { + access_token: token.signed_token, + token_type: "Bearer".to_string(), + expires_in: Some(claims.exp - Utc::now().timestamp()), + refresh_token: None, + scopes: None, + }) + .unwrap() + .into(), + )?) + } + Err(_) => { + // Do not log the error here as we want to ensure we do not leak token information + tracing::debug!( + "Failed to parse a success response from the remote token endpoint" + ); + + Ok(handle_token_parse_failure( + &path.provider.to_string(), + bytes, + headers, + status, + )) + } + } + } +} + +/// Headers that are safe to forward from an upstream OAuth provider response. +/// Only `Content-Type` is needed so the client can parse the body. Polling backoff +/// is handled via the JSON body per RFC 8628 (`interval` field / `slow_down` error), +/// not via HTTP headers. +const FORWARDED_HEADERS: &[header::HeaderName] = &[header::CONTENT_TYPE]; + +/// Copy only allowlisted headers from an upstream response to avoid forwarding +/// dangerous headers such as `Set-Cookie`, `Location`, or CORS headers. +fn filter_upstream_headers(upstream: &HeaderMap) -> HeaderMap { + let mut filtered = HeaderMap::new(); + for name in FORWARDED_HEADERS { + if let Some(value) = upstream.get(name) { + filtered.insert(name.clone(), value.clone()); + } + } + filtered +} + +/// Build a response to the client by proxying an upstream provider's response. This is used +/// when the upstream provider returns a non-success status code. +fn proxy_upstream_response(bytes: Bytes, headers: HeaderMap, status: StatusCode) -> Response { + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + client_response +} + +/// Handle the case where the upstream provider returned a 200 status but the body could not be +/// parsed as a valid token response. We try to interpret the body as an error response and proxy +/// it back. If the body is not a recognizable error either, we return our own error. +fn handle_token_parse_failure( + provider_name: &str, + bytes: Bytes, + headers: HeaderMap, + status: StatusCode, +) -> Response { + // Try to deserialize the body as an error + let mut error_response = match serde_json::from_slice::(&bytes) { + Ok(error) => { + // We found an error in the message body. This is not ideal, but we at + // least can understand what the server was trying to tell us + tracing::debug!( + ?error, + provider_name, + "Parsed error response from OAuth provider" + ); + + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + + client_response + } + Err(_) => { + // We still do not know what the remote server is doing... and need to + // cancel the request ourselves + tracing::warn!("Remote OAuth provider returned a response that we do not understand"); + + Response::new( + serde_json::to_vec(&ProxyTokenError { + error: "access_denied".to_string(), + error_description: Some(format!( + "{} returned a malformed response", + provider_name + )), + error_uri: None, + }) + .unwrap() + .into(), + ) + } + }; + + *error_response.status_mut() = StatusCode::BAD_REQUEST; + error_response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + error_response +} + +#[cfg(test)] +mod tests { + use http::{ + header::{self, HeaderName, SET_COOKIE}, + HeaderMap, HeaderValue, StatusCode, + }; + use hyper::body::Bytes; + + use super::{handle_token_parse_failure, proxy_upstream_response}; + + #[test] + fn test_upstream_set_cookie_is_stripped_from_error_response() { + // A malicious or compromised upstream provider includes a Set-Cookie header + // that would set a cookie on our API's domain in the user's browser + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("session=malicious-value; Path=/; HttpOnly"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::FORBIDDEN); + + // The Set-Cookie header must NOT be forwarded to our client + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded to the client" + ); + // But Content-Type should still be forwarded + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be forwarded from upstream" + ); + } + + #[test] + fn test_upstream_cors_headers_are_stripped_from_error_response() { + // A malicious upstream provider injects permissive CORS headers that would + // weaken our API's cross-origin protections + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("*"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::BAD_REQUEST); + + // CORS headers must NOT be forwarded from upstream + assert!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none(), + "Upstream CORS header must not be forwarded to the client" + ); + assert!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS) + .is_none(), + "Upstream CORS credentials header must not be forwarded to the client" + ); + } + + #[test] + fn test_upstream_location_and_framing_headers_are_stripped_from_token_parse_failure() { + // When the upstream returns a 200 status but the body is an error (not a valid + // token), handle_token_parse_failure must not forward dangerous headers. + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::LOCATION, + HeaderValue::from_static("https://evil.example.com/phishing"), + ); + upstream_headers.insert( + HeaderName::from_static("x-frame-options"), + HeaderValue::from_static("ALLOW-FROM https://evil.example.com"), + ); + + // Body that parses as a ProxyTokenError but NOT as a valid token + let body = Bytes::from_static( + b"{\"error\": \"slow_down\", \"error_description\": null, \"error_uri\": null}", + ); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // Dangerous headers must NOT be forwarded + assert!( + response.headers().get(header::LOCATION).is_none(), + "Upstream Location header must not be forwarded to the client" + ); + assert!( + response.headers().get("x-frame-options").is_none(), + "Upstream X-Frame-Options header must not be forwarded to the client" + ); + // But Content-Type should still be present (set by the function itself) + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be present on the response" + ); + } + + #[test] + fn test_upstream_set_cookie_is_stripped_from_token_parse_failure() { + // Even when the upstream returns 200 and the body is a parseable error, + // a Set-Cookie header must not be forwarded to our client + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("tracking=evil-tracker; Domain=.our-api.com; Path=/"), + ); + + let body = Bytes::from_static( + b"{\"error\": \"authorization_pending\", \"error_description\": null, \"error_uri\": null}", + ); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // The Set-Cookie header must NOT be forwarded + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded via token parse failure path" + ); + } +} diff --git a/v-api/src/endpoints/login/oauth/flow/mod.rs b/v-api/src/endpoints/login/oauth/flow/mod.rs new file mode 100644 index 00000000..304abf26 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/mod.rs @@ -0,0 +1,2 @@ +pub mod code; +pub mod device_token; diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs deleted file mode 100644 index a38024d7..00000000 --- a/v-api/src/endpoints/login/oauth/google.rs +++ /dev/null @@ -1,188 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use hyper::body::Bytes; -use reqwest::Request; -use secrecy::SecretString; -use serde::Deserialize; -use std::fmt; - -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, -}; - -pub struct GoogleOAuthProvider { - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, - client: reqwest::Client, -} - -impl fmt::Debug for GoogleOAuthProvider { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("GoogleOAuthProvider").finish() - } -} - -impl GoogleOAuthProvider { - pub fn new( - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, - additional_scopes: Option>, - ) -> Self { - Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, - }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), - client: reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("Static client must build"), - } - } - - pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { - self.client = client; - self - } -} - -#[derive(Debug, Deserialize)] -struct GoogleUserInfo { - sub: String, - email: String, - email_verified: bool, -} - -#[derive(Debug, Deserialize)] -struct GoogleProfile { - #[serde(default)] - names: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GoogleProfileName { - display_name: String, - metadata: GoogleProfileNameMeta, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GoogleProfileNameMeta { - #[serde(default)] - primary: bool, -} - -impl ExtractUserInfo for GoogleOAuthProvider { - // There should always be as many entries in the data list as there are endpoints. This should - // be changed in the future to be a static check - fn extract_user_info(&self, data: &[Bytes]) -> Result { - let remote_info: GoogleUserInfo = serde_json::from_slice(&data[0])?; - let verified_emails = if remote_info.email_verified { - vec![remote_info.email] - } else { - vec![] - }; - - let profile_info: GoogleProfile = serde_json::from_slice(&data[1])?; - let display_name = profile_info - .names - .into_iter() - .filter_map(|name| name.metadata.primary.then_some(name.display_name)) - .nth(0); - - Ok(UserInfo { - external_id: ExternalUserId::Google(remote_info.sub), - verified_emails, - display_name, - }) - } -} - -impl OAuthProvider for GoogleOAuthProvider { - fn name(&self) -> OAuthProviderName { - OAuthProviderName::Google - } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["openid", "email", "profile"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - - fn initialize_headers(&self, _request: &mut Request) {} - - fn client(&self) -> &reqwest::Client { - &self.client - } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - - fn user_info_endpoints(&self) -> Vec<&str> { - vec![ - "https://openidconnect.googleapis.com/v1/userinfo", - "https://people.googleapis.com/v1/people/me?personFields=names", - ] - } - - fn device_code_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/device/code" - } - - fn auth_url_endpoint(&self) -> &str { - "https://accounts.google.com/o/oauth2/v2/auth" - } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" - } - - fn token_exchange_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/token" - } - - fn token_revocation_endpoint(&self) -> Option<&str> { - Some("https://oauth2.googleapis.com/revoke") - } - - fn supports_pkce(&self) -> bool { - true - } -} diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index f89d822d..bc071c2c 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -5,44 +5,48 @@ use async_trait::async_trait; use http::Method; use hyper::{body::Bytes, header::HeaderValue, header::AUTHORIZATION}; +use newtype_uuid::TypedUuid; use oauth2::{ basic::BasicClient, url::ParseError, AuthUrl, ClientId, ClientSecret, EndpointMaybeSet, EndpointNotSet, EndpointSet, RedirectUrl, RevocationUrl, TokenUrl, }; use reqwest::Request; use schemars::JsonSchema; -use secrecy::{ExposeSecret, SecretString}; +use secrecy::ExposeSecret; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display}; use thiserror::Error; use tracing::instrument; -use v_model::OAuthClient; +use v_model::{OAuthClient, OAuthClientId}; -use crate::authn::{key::RawKey, Verify}; +use crate::{ + authn::{key::RawKey, Verify}, + secrets::OpenApiSecretString, +}; -use super::{UserInfo, UserInfoError, UserInfoProvider}; +use super::{is_redirect_uri_valid, UserInfo, UserInfoError, UserInfoProvider}; pub mod client; -pub mod code; -pub mod device_token; -pub mod github; -pub mod google; +pub mod flow; +pub mod remote; #[derive(Debug, Error)] pub enum OAuthProviderError { #[error("Unable to instantiate invalid provider")] FailToCreateInvalidProvider, + #[error("Missing redirect URI")] + MissingRedirectUri, + #[error("Failed to parse URL")] + UrlParseError(#[from] ParseError), + #[error("Provider does not support web clients")] + WebClientNotSupported, } #[derive(Debug)] pub enum ClientType { Device, Web, -} - -#[derive(Debug)] -pub struct WebClientConfig { - prefix: String, + WebPkce, } pub type WebClient = BasicClient< @@ -58,67 +62,46 @@ pub type WebClient = BasicClient< EndpointSet, >; -pub struct OAuthPublicCredentials { - client_id: String, -} - -pub struct OAuthPrivateCredentials { - client_secret: SecretString, -} - pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn name(&self) -> OAuthProviderName; - fn scopes(&self) -> Vec<&str>; fn initialize_headers(&self, request: &mut Request); fn client(&self) -> &reqwest::Client; - fn client_id(&self, client_type: &ClientType) -> &str; - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString>; - - // TODO: How can user info be change to something statically checked instead of a runtime check fn user_info_endpoints(&self) -> Vec<&str>; - fn device_code_endpoint(&self) -> &str; - fn auth_url_endpoint(&self) -> &str; - fn token_exchange_content_type(&self) -> &str; - fn token_exchange_endpoint(&self) -> &str; - fn token_revocation_endpoint(&self) -> Option<&str>; - fn supports_pkce(&self) -> bool; - fn provider_info(&self, public_url: &str, client_type: &ClientType) -> OAuthProviderInfo { - OAuthProviderInfo { - provider: self.name(), - client_id: self.client_id(client_type).to_string(), - auth_url_endpoint: self.auth_url_endpoint().to_string(), - device_code_endpoint: self.device_code_endpoint().to_string(), - token_endpoint: format!("{}/login/oauth/{}/device/exchange", public_url, self.name(),), - scopes: self - .scopes() - .into_iter() - .map(|s| s.to_string()) - .collect::>(), - } - } + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo>; + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo>; + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo>; - fn as_web_client(&self, config: &WebClientConfig) -> Result { - let mut client = - BasicClient::new(ClientId::new(self.client_id(&ClientType::Web).to_string())) - .set_auth_uri(AuthUrl::new(self.auth_url_endpoint().to_string())?) - .set_token_uri(TokenUrl::new(self.token_exchange_endpoint().to_string())?) - .set_revocation_url_option( - self.token_revocation_endpoint() - .map(|s| RevocationUrl::new(s.to_string())) - .transpose()?, - ) - .set_redirect_uri(RedirectUrl::new(format!( - "{}/login/oauth/{}/code/callback", - &config.prefix, - self.name() - ))?); - - if let Some(secret) = self.client_secret(&ClientType::Web) { - client = client.set_client_secret(ClientSecret::new(secret.expose_secret().to_string())) - } + fn default_scopes(&self) -> &[String]; + + /// Whether the remote OAuth provider supports PKCE (RFC 7636). Providers must + /// explicitly declare this. This controls whether v-api sends a PKCE challenge + /// to the remote provider during the authorization code exchange. Note: clients + /// calling v-api are always required to use PKCE regardless of this setting. + fn supports_pkce(&self) -> bool; - Ok(client) + fn as_web_client(&self) -> Result { + match self.authz_code_flow_info() { + Some(info) => { + let client = BasicClient::new(ClientId::new(info.remote.client_id.clone())) + .set_auth_uri(AuthUrl::new(info.remote.auth_url_endpoint.clone())?) + .set_token_uri(TokenUrl::new(info.remote.token_endpoint.clone())?) + .set_revocation_url_option( + info.remote + .revocation_endpoint + .as_ref() + .map(|url| RevocationUrl::new(url.to_string())) + .transpose()?, + ) + .set_redirect_uri(RedirectUrl::new(info.redirect_endpoint.to_string())?) + .set_client_secret(ClientSecret::new( + info.remote.client_secret.0.expose_secret().to_string(), + )); + + Ok(client) + } + None => Err(OAuthProviderError::WebClientNotSupported), + } } } @@ -155,33 +138,88 @@ where ); let response = self.client().execute(request).await?; - - tracing::trace!(status = ?response.status(), "Received response from OAuth provider"); + let status = response.status(); + + tracing::trace!(?status, "Received response from OAuth provider"); + + if !status.is_success() { + tracing::error!( + ?status, + endpoint, + "User info endpoint returned non-success status" + ); + return Err(UserInfoError::UnexpectedStatus { + endpoint: endpoint.to_string(), + status, + }); + } let bytes = response.bytes().await?; responses.push(bytes); } - self.extract_user_info(&responses) + let mut info = self.extract_user_info(&responses)?; + info.idp_token = Some(token.to_string()); + Ok(info) } } -#[derive(Debug, Deserialize, Serialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, + code: Option, + pkce: Option, + device: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: String, + redirect_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + remote: OAuthProviderAuthorizationCodeRemoteInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodeRemoteInfo { + client_id: String, + #[serde(skip_serializing)] + client_secret: OpenApiSecretString, auth_url_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + revocation_endpoint: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodePkceInfo { + client_id: TypedUuid, + redirect_endpoint: String, + proxy_port: u16, + web: OAuthProviderAuthorizationCodeInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderDeviceInfo { + client_id: TypedUuid, + remote_client_id: String, + #[serde(skip_serializing)] + remote_client_secret: OpenApiSecretString, device_code_endpoint: String, + token_endpoint_content_type: String, token_endpoint: String, - scopes: Vec, + revocation_endpoint: Option, } -#[derive(Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] #[serde(rename_all = "kebab-case")] pub enum OAuthProviderName { #[serde(rename = "github")] GitHub, Google, + Zendesk, } impl Display for OAuthProviderName { @@ -189,6 +227,7 @@ impl Display for OAuthProviderName { match self { OAuthProviderName::GitHub => write!(f, "github"), OAuthProviderName::Google => write!(f, "google"), + OAuthProviderName::Zendesk => write!(f, "zendesk"), } } } @@ -224,8 +263,9 @@ impl CheckOAuthClient for OAuthClient { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs similarity index 50% rename from v-api/src/endpoints/login/oauth/github.rs rename to v-api/src/endpoints/login/oauth/remote/github.rs index bce13561..7d59e260 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -5,26 +5,27 @@ use http::{header::USER_AGENT, HeaderMap, HeaderValue}; use hyper::body::Bytes; use reqwest::Request; -use secrecy::SecretString; use serde::Deserialize; use std::fmt; -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, }; +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + pub struct GitHubOAuthProvider { - // public: GitHubPublicProvider, - // private: Option, - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, + authz_code_flow_info: Option, + device_code_flow_info: Option, default_headers: HeaderMap, + default_scopes: Vec, client: reqwest::Client, } @@ -36,30 +37,45 @@ impl fmt::Debug for GitHubOAuthProvider { impl GitHubOAuthProvider { pub fn new( - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, + config: ResolvedOAuthConfig, + public_url: String, additional_scopes: Option>, ) -> Self { let mut headers = HeaderMap::new(); headers.insert(USER_AGENT, HeaderValue::from_static("v-api")); - Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, + let mut default_scopes = vec!["user:email".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/github/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/github/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://github.com/login/oauth/authorize".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), + device_code_endpoint: "https://github.com/login/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }); + + Self { + authz_code_flow_info, + device_code_flow_info, default_headers: headers, + default_scopes, client: reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build() @@ -104,6 +120,7 @@ impl ExtractUserInfo for GitHubOAuthProvider { external_id: ExternalUserId::GitHub(user.id.to_string()), verified_emails, display_name: Some(user.login), + idp_token: None, }) } } @@ -112,69 +129,32 @@ impl OAuthProvider for GitHubOAuthProvider { fn name(&self) -> OAuthProviderName { OAuthProviderName::GitHub } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["user:email"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - fn initialize_headers(&self, request: &mut Request) { *request.headers_mut() = self.default_headers.clone(); } - fn client(&self) -> &reqwest::Client { &self.client } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - fn user_info_endpoints(&self) -> Vec<&str> { vec![ "https://api.github.com/user", "https://api.github.com/user/emails", ] } - - fn device_code_endpoint(&self) -> &str { - "https://github.com/login/device/code" - } - - fn auth_url_endpoint(&self) -> &str { - "https://github.com/login/oauth/authorize" + fn default_scopes(&self) -> &[String] { + &self.default_scopes } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" + fn supports_pkce(&self) -> bool { + false } - fn token_exchange_endpoint(&self) -> &str { - "https://github.com/login/oauth/access_token" + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() } - - fn token_revocation_endpoint(&self) -> Option<&str> { + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { None } - - fn supports_pkce(&self) -> bool { - true + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() } } diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs new file mode 100644 index 00000000..23879e89 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -0,0 +1,185 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use serde::Deserialize; +use std::fmt; + +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, +}; + +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + +pub struct GoogleOAuthProvider { + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, + device_code_flow_info: Option, + default_scopes: Vec, + client: reqwest::Client, +} + +impl fmt::Debug for GoogleOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GoogleOAuthProvider").finish() + } +} + +impl GoogleOAuthProvider { + pub fn new( + config: ResolvedOAuthConfig, + public_url: String, + additional_scopes: Option>, + ) -> Self { + let mut default_scopes = vec![ + "openid".to_string(), + "email".to_string(), + "profile".to_string(), + ]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/google/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/google/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/google/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), + }, + }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), + device_code_endpoint: "https://oauth2.googleapis.com/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), + }); + + Self { + authz_code_flow_info, + authz_code_pkce_flow_info, + device_code_flow_info, + default_scopes, + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct GoogleUserInfo { + sub: String, + email: String, + email_verified: bool, +} + +#[derive(Debug, Deserialize)] +struct GoogleProfile { + #[serde(default)] + names: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GoogleProfileName { + display_name: String, + metadata: GoogleProfileNameMeta, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GoogleProfileNameMeta { + #[serde(default)] + primary: bool, +} + +impl ExtractUserInfo for GoogleOAuthProvider { + // There should always be as many entries in the data list as there are endpoints. This should + // be changed in the future to be a static check + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let remote_info: GoogleUserInfo = serde_json::from_slice(&data[0])?; + let verified_emails = if remote_info.email_verified { + vec![remote_info.email] + } else { + vec![] + }; + + let profile_info: GoogleProfile = serde_json::from_slice(&data[1])?; + let display_name = profile_info + .names + .into_iter() + .filter_map(|name| name.metadata.primary.then_some(name.display_name)) + .nth(0); + + Ok(UserInfo { + external_id: ExternalUserId::Google(remote_info.sub), + verified_emails, + display_name, + idp_token: None, + }) + } +} + +impl OAuthProvider for GoogleOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Google + } + fn initialize_headers(&self, _request: &mut Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![ + "https://openidconnect.googleapis.com/v1/userinfo", + "https://people.googleapis.com/v1/people/me?personFields=names", + ] + } + fn default_scopes(&self) -> &[String] { + &self.default_scopes + } + fn supports_pkce(&self) -> bool { + true + } + + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() + } + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + self.authz_code_pkce_flow_info.as_ref() + } + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() + } +} diff --git a/v-api/src/endpoints/login/oauth/remote/mod.rs b/v-api/src/endpoints/login/oauth/remote/mod.rs new file mode 100644 index 00000000..db2a1dd9 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/mod.rs @@ -0,0 +1,3 @@ +pub mod github; +pub mod google; +pub mod zendesk; diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs new file mode 100644 index 00000000..51ecd31a --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -0,0 +1,156 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use serde::Deserialize; +use std::fmt; + +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, +}; + +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + +pub struct ZendeskOAuthProvider { + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, + user_info_endpoint: String, + default_scopes: Vec, + client: reqwest::Client, +} + +impl fmt::Debug for ZendeskOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ZendeskOAuthProvider").finish() + } +} + +impl ZendeskOAuthProvider { + pub fn new( + config: ResolvedOAuthConfig, + public_url: String, + subdomain: String, + additional_scopes: Option>, + ) -> Self { + let base_url = format!("https://{}.zendesk.com", subdomain); + + let mut default_scopes = vec!["read".to_string(), "write".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/zendesk/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/zendesk/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/oauth/tokens", base_url), + revocation_endpoint: None, + }, + }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); + + Self { + authz_code_flow_info, + authz_code_pkce_flow_info, + user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + default_scopes, + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct ZendeskUserResponse { + user: ZendeskUser, +} + +#[derive(Debug, Deserialize)] +struct ZendeskUser { + id: u64, + name: String, + email: String, + verified: bool, + suspended: bool, +} + +impl ExtractUserInfo for ZendeskOAuthProvider { + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let response: ZendeskUserResponse = serde_json::from_slice(&data[0])?; + let user = response.user; + + if user.suspended { + return Err(UserInfoError::Locked); + } + + let verified_emails = if user.verified { + vec![user.email] + } else { + vec![] + }; + + Ok(UserInfo { + external_id: ExternalUserId::Zendesk(user.id.to_string()), + verified_emails, + display_name: Some(user.name), + idp_token: None, + }) + } +} + +impl OAuthProvider for ZendeskOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Zendesk + } + fn initialize_headers(&self, _request: &mut Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![&self.user_info_endpoint] + } + fn default_scopes(&self) -> &[String] { + &self.default_scopes + } + fn supports_pkce(&self) -> bool { + true + } + + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() + } + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + self.authz_code_pkce_flow_info.as_ref() + } + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + None + } +} diff --git a/v-api/src/secrets.rs b/v-api/src/secrets.rs index a4f95604..c55ab6a4 100644 --- a/v-api/src/secrets.rs +++ b/v-api/src/secrets.rs @@ -10,7 +10,7 @@ use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize, Serializer}; use std::borrow::Cow; -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OpenApiSecretString(pub SecretString); impl From for OpenApiSecretString { diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml new file mode 100644 index 00000000..030df94a --- /dev/null +++ b/v-cli-sdk/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "v-cli-sdk" +version = "0.2.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +chrono = { workspace = true } +clap = { workspace = true } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["server", "http1"] } +hyper-util = { workspace = true, features = ["tokio"] } +oauth2 = { workspace = true } +oauth2-reqwest = { workspace = true } +owo-colors = { workspace = true } +progenitor-client = { workspace = true } +reqwest = { workspace = true } +schemars = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tabwriter = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } +uuid = { workspace = true } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs new file mode 100644 index 00000000..3e21691a --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -0,0 +1,282 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand, ValueEnum}; +use oauth2::TokenResponse; +use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; + +use crate::{ + cmd::auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, + VCliConfig, VCliContext, +}; + +pub trait CliAdapterToken { + fn access_token(&self) -> &str; + fn idp_token(&self) -> Option<&str>; +} + +pub trait CliConsumerLoginProvider: Into + Subcommand + Debug + Clone {} +impl CliConsumerLoginProvider for T where T: Into + Subcommand + Debug + Clone {} + +// Authenticates and generates an access token for interacting with the api +#[derive(Parser, Debug, Clone)] +#[clap(name = "login")] +pub struct Login +where + SupportedProviders: CliConsumerLoginProvider, +{ + #[command(subcommand)] + method: LoginMethod, + #[arg(short = 'm', default_value = "id")] + mode: AuthenticationMode, +} + +impl

Login

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + let (access_token, idp_token) = self.method.run(ctx, self.mode).await?; + + ctx.config_mut().set_token(access_token); + ctx.config_mut().save()?; + + // If we are acquiring an IdP token, present it to the user. + if let Some(idp_token) = idp_token { + println!( + "\nYou can now additionally authenticate against the requested remote service API \ + with the following token." + ); + println!("IdP token: {}", idp_token); + println!(); + println!( + "Please note that this should be kept secure as calls made with this token are \ + made on behalf of your user acount" + ); + } + + Ok(()) + } +} + +#[derive(Subcommand, Debug, Clone)] +pub enum LoginMethod +where + SupportedProviders: Subcommand + Debug + Clone, +{ + #[command(name = "oauth")] + /// Login via OAuth + OAuth { + #[command(subcommand)] + provider: SupportedProviders, + /// Additionally retrieve a the underlying IdP token. This token is not stored. Remote mode + /// should be used when you need to authenticate to the underlying system frontend by the API + #[arg(long, default_value = "false")] + request_idp_token: bool, + }, + /// Login via Magic Link + #[command(name = "mlink")] + MagicLink { + /// Email recipient to login via + email: String, + /// Optional access scopes to apply to this session + scope: Option, + }, +} + +pub enum LoginProvider { + Google, + GitHub, + Zendesk, +} + +#[derive(Copy, ValueEnum, Debug, Clone, PartialEq)] +pub enum AuthenticationMode { + /// Retrieve and store an identity token. Identity mode is the default and should be used to + /// when you do not require extended (multi-day) access + #[value(name = "id")] + Identity, + /// Retrieve and store an api token. Token mode should be used when you want to authenticate + /// a machine for continued access. This requires the permission to create api tokens + #[value(name = "token")] + Token, +} + +impl LoginMethod +where + SupportedProviders: CliConsumerLoginProvider, +{ + pub async fn run( + &self, + ctx: &T, + mode: AuthenticationMode, + ) -> Result<(String, Option)> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + match self { + Self::OAuth { + provider, + request_idp_token, + } => { + let adapter = ctx.oauth_adapter(); + let provider = provider.clone().into(); + let provider = adapter.provider(provider).await?; + + // We now need to inspect the provider to determine the correct flow to use. If + // possible we use a limited input device flow, but not all providers support it. + // To handle those cases we need to use a proxy path that emulates an authorization + // code flow. + if provider.device_code_endpoint().is_some() { + if *request_idp_token { + anyhow::bail!( + "Remote token access is not supported via device authentication flow" + ); + } + Ok(( + self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await?, + None, + )) + } else if provider.supports_pkce_only() { + self.run_oauth_code_provider( + provider, + mode, + *request_idp_token, + ctx.oauth_adapter(), + ) + .await + } else { + anyhow::bail!("OAuth provider does not support any CLI authentication methods") + } + } + Self::MagicLink { email, scope } => Ok(( + self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await?, + None, + )), + } + } + + async fn run_oauth_device_provider( + &self, + provider: V, + mode: AuthenticationMode, + adapter: T, + ) -> Result + where + T: CliOAuthAdapter, + V: CliOAuthProviderInfo, + { + let oauth_client = oauth::device::DeviceOAuth::new(provider)?; + let details = oauth_client.get_device_authorization().await?; + + println!( + "To complete login visit: {} and enter {}", + details.verification_uri().as_str(), + details.user_code().secret() + ); + + let token_response = oauth_client.login(&details).await; + + let identity_token = match token_response { + Ok(token) => Ok(token.access_token().to_owned()), + Err(err) => Err(anyhow::anyhow!("Authentication failed: {}", err)), + }?; + + match mode { + AuthenticationMode::Identity => Ok(identity_token.secret().to_string()), + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.secret()) + .await?; + Ok(token.access_token().to_string()) + } + } + } + + async fn run_oauth_code_provider( + &self, + provider: T, + mode: AuthenticationMode, + request_idp_token: bool, + adapter: V, + ) -> Result<(String, Option)> + where + T: CliOAuthProviderInfo, + V: CliOAuthAdapter + Send + Sync + 'static, + { + let oauth_client = oauth::code::CodeOAuth::new(provider)?; + let adapter = Arc::new(adapter); + + let identity_token = oauth_client + .login(Arc::clone(&adapter), request_idp_token) + .await?; + + let access_token = match mode { + AuthenticationMode::Identity => identity_token.access_token().to_string(), + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.access_token()) + .await?; + token.access_token().to_string() + } + }; + + let idp_token = if request_idp_token { + identity_token.idp_token().map(|s| s.to_string()) + } else { + None + }; + + Ok((access_token, idp_token)) + } + + async fn run_magic_link( + &self, + email: &str, + scope: Option<&str>, + adapter: T, + ) -> Result + where + T: CliMagicLinkAdapter, + { + let attempt = adapter.create_attempt(email, scope).await?; + + let mut auth_secret = String::new(); + print!("Enter the login token sent to the recipient: "); + std::io::stdout().flush()?; + std::io::stdin().read_line(&mut auth_secret)?; + + let token = adapter.exchange(attempt, email, &auth_secret).await?; + + Ok(token.access_token().to_string()) + } +} + +pub trait CliMagicLinkAdapter { + type Attempt; + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + #[allow(clippy::type_complexity)] + fn create_attempt( + &self, + email: &str, + scope: Option<&str>, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn exchange( + &self, + attempt: Self::Attempt, + email: &str, + token: &str, + ) -> Pin> + Send>>; +} diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs new file mode 100644 index 00000000..ad7d6a24 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -0,0 +1,48 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; +use std::error::Error as StdError; + +use crate::{cmd::auth::login::CliConsumerLoginProvider, VCliContext}; + +pub mod login; +pub mod oauth; +pub mod proxy; + +// Authenticate against the Meetings API +#[derive(Parser, Debug)] +#[clap(name = "auth")] +pub struct Auth

+where + P: CliConsumerLoginProvider, +{ + #[command(subcommand)] + auth: AuthCommands

, +} + +#[derive(Subcommand, Debug, Clone)] +enum AuthCommands

+where + P: CliConsumerLoginProvider, +{ + /// Login via an authentication provider + Login(login::Login

), +} + +impl

Auth

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + match &self.auth { + AuthCommands::Login(login) => login.run(ctx).await, + } + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs new file mode 100644 index 00000000..0177d874 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -0,0 +1,228 @@ +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, +}; + +use anyhow::Result; +use http::{Request, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; + +use oauth2::{ + basic::BasicClient, AuthType, AuthUrl, ClientId, CsrfToken, EndpointNotSet, EndpointSet, + PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, +}; +use tokio::sync::oneshot; +use uuid::Uuid; + +use crate::cmd::auth::{ + oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, + proxy::run_proxy_server, +}; + +type CodeClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointNotSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct CodeOAuth { + client: CodeClient, + client_id: Uuid, + redirect_uri: String, + scopes: Vec, + port: u16, +} + +impl CodeOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .auth_url_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define an authorization url") + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_redirect_uri(RedirectUrl::new( + provider + .redirect_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a redirect url") + })? + .to_string(), + )?); + + Ok(Self { + client, + client_id: provider.client_id(), + redirect_uri: provider.redirect_endpoint().unwrap_or_default().to_string(), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + port: provider.public_pkce_port().ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a public proxy port") + })?, + }) + } + + /// Build the authorization URL that the user should visit in a browser. + /// Returns the full URL and the CSRF state token used for verification. + pub fn authorize_url( + &self, + pkce_challenge: PkceCodeChallenge, + ) -> (oauth2::url::Url, CsrfToken) { + let mut req = self + .client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + req.url() + } + + /// Run the full authorization code login flow: + /// + /// 1. Generate the authorization URL and print it for the user. + /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. + /// 3. Forward the redirect request to the API server via the adapter. + /// 4. Extract the token from the server's response. + /// 5. Return a success page to the browser and shut down the proxy. + pub async fn login(&self, adapter: Arc, request_idp_token: bool) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, _csrf_state) = self.authorize_url(pkce_challenge); + + println!( + "Open the following URL in your browser to authenticate:\n\n {}\n", + auth_url + ); + + // Channel to receive the token extracted from the server response. + let (token_tx, token_rx) = oneshot::channel::>(); + #[allow(clippy::type_complexity)] + let token_tx: Arc>>>> = + Arc::new(Mutex::new(Some(token_tx))); + + // Channel to shut down the proxy server once we have the token. + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let port = self.port; + + // Spawn the local proxy server in a background task. + tokio::spawn({ + let callback_token_tx = Arc::clone(&token_tx); + let error_token_tx = Arc::clone(&token_tx); + let client_id = self.client_id; + let redirect_uri = self.redirect_uri.clone(); + + async move { + let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some( + Box::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + let code = request + .uri() + .query() + .and_then(|q: &str| { + q.split('&') + .filter_map(|pair: &str| pair.split_once('=')) + .find(|(key, _): &(&str, &str)| *key == "code") + .map(|(_, value): (&str, &str)| value.to_string()) + }) + .ok_or_else(|| { + anyhow::anyhow!( + "Missing 'code' query parameter in callback request" + ) + })?; + + // Forward the redirect request to the API server. + let token = adapter + .exchange_authorization_code(super::AuthorizationCodeExchange { + provider: crate::cmd::auth::login::LoginProvider::Zendesk, + client_id, + redirect_uri: redirect_uri.clone(), + grant_type: "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + }) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Ok(token)); + } + } + + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/html; charset=utf-8") + .body(Full::new(Bytes::from(concat!( + "", + "", + "

Authentication successful. This window should close automatically.

", + "" + ))))?) + }) + as Pin< + Box< + dyn Future>>> + + Send, + >, + > + }), + ))); + + if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { + eprintln!("Proxy server error: {e}"); + + // If the proxy died before we got a token, unblock the + // receiver so the caller is not stuck forever. + if let Ok(mut guard) = error_token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); + } + } + } + } + }); + + // Wait for the proxy callback to extract the token. + let token = token_rx.await.map_err(|_| { + anyhow::anyhow!( + "Authentication callback was never received — proxy server may have exited early" + ) + })??; + + // Tell the proxy server to stop. + let _ = shutdown_tx.send(()); + + Ok(token) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/device.rs b/v-cli-sdk/src/cmd/auth/oauth/device.rs new file mode 100644 index 00000000..4acf227b --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/device.rs @@ -0,0 +1,92 @@ +use anyhow::Result; +use oauth2::{ + basic::{BasicClient, BasicTokenType}, + AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, + EndpointSet, Scope, StandardDeviceAuthorizationResponse, StandardTokenResponse, TokenUrl, +}; + +use crate::cmd::auth::oauth::CliOAuthProviderInfo; + +type DeviceClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct DeviceOAuth { + client: DeviceClient, + http: oauth2_reqwest::ReqwestClient, + scopes: Vec, +} + +impl DeviceOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + if let Some(device_endpoint) = provider.device_code_endpoint() { + let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .device_code_endpoint() + .ok_or_else(|| { + anyhow::anyhow!( + "OAuth device flow provider must define an device code url" + ) + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_device_authorization_url(device_auth_url); + + Ok(Self { + client, + http: oauth2_reqwest::ReqwestClient::from( + reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(), + ), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + }) + } else { + anyhow::bail!("Device authorization is not supported by this provider") + } + } + + pub async fn login( + &self, + details: &StandardDeviceAuthorizationResponse, + ) -> Result> { + let token = self + .client + .exchange_device_access_token(details) + .set_max_backoff_interval(details.interval()) + .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) + .await; + + Ok(token?) + } + + pub async fn get_device_authorization(&self) -> Result { + let mut req = self.client.exchange_device_code(); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + let res = req.request_async(&self.http).await; + + Ok(res?) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs new file mode 100644 index 00000000..8a65605a --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -0,0 +1,58 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use oauth2::PkceCodeVerifier; +use std::{error::Error as StdError, future::Future, pin::Pin}; +use uuid::Uuid; + +pub mod code; +pub mod device; + +use crate::cmd::auth::login::CliAdapterToken; + +/// Parameters for exchanging an authorization code for an access token. +pub struct AuthorizationCodeExchange { + pub provider: super::login::LoginProvider, + pub client_id: Uuid, + pub redirect_uri: String, + pub grant_type: String, + pub code: String, + pub pkce_verifier: PkceCodeVerifier, + pub request_idp_token: bool, +} + +pub trait CliOAuthAdapter { + type ShortToken: CliAdapterToken + Send + 'static; + type LongToken: CliAdapterToken + Send + 'static; + type Error: StdError + Send + Sync + 'static; + + #[allow(clippy::type_complexity)] + fn provider( + &self, + provider: super::login::LoginProvider, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn exchange_authorization_code( + &self, + exchange: AuthorizationCodeExchange, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn get_long_lived_token( + &self, + access_token: &str, + ) -> Pin> + Send>>; +} + +pub trait CliOAuthProviderInfo { + fn client_id(&self) -> Uuid; + fn remote_client_id(&self) -> &str; + fn public_pkce_port(&self) -> Option; + fn supports_pkce_only(&self) -> bool; + fn device_code_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> Option<&str>; + fn token_endpoint(&self) -> &str; + fn redirect_endpoint(&self) -> Option<&str>; + fn scopes(&self) -> &[String]; +} diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs new file mode 100644 index 00000000..527584f7 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -0,0 +1,157 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; +use hyper::service::service_fn; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +/// A callback function that receives an incoming HTTP request and returns a response. +/// Wrapped in `Arc>>` so it can be called at most once. +pub type CallbackFn = Box< + dyn FnOnce( + Request, + ) -> std::pin::Pin< + Box>>> + Send>, + > + Send, +>; + +/// A shareable, single-use callback. The first request to arrive `.take()`s +/// the inner function; any subsequent request receives an error response. +pub type Callback = Arc>>; + +/// Start a minimal HTTP server on the given port that forwards every incoming +/// request to `callback` and returns whatever response the callback produces. +/// +/// The server will run until a message is sent on the `shutdown` channel, at +/// which point it will stop accepting new connections and return. +pub async fn run_proxy_server( + port: u16, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + serve_loop(listener, callback, shutdown).await +} + +/// Core accept-loop shared by [`run_proxy_server`] and tests. +/// +/// Accepts connections on `listener`, forwarding each request to `callback`. +/// Stops when `shutdown` fires. +async fn serve_loop( + listener: TcpListener, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + tokio::pin!(shutdown); + + loop { + tokio::select! { + _ = &mut shutdown => { + break; + } + accepted = listener.accept() => { + let (stream, _remote_addr) = accepted?; + let io = TokioIo::new(stream); + let cb = Arc::clone(&callback); + + tokio::task::spawn(async move { + let service = service_fn(move |req: Request| { + let cb = Arc::clone(&cb); + async move { + let handler = cb + .lock() + .expect("callback mutex poisoned") + .take(); + + match handler { + Some(f) => f(req).await, + None => Ok(Response::builder() + .status(StatusCode::GONE) + .body(Full::new(Bytes::from( + "Callback has already been invoked", + ))) + .expect("building static response cannot fail")), + } + } + }); + + if let Err(err) = + hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {err}"); + } + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::StatusCode; + + #[tokio::test] + async fn test_proxy_server_responds() { + let callback: Callback = Arc::new(Mutex::new(Some(Box::new(|_req| { + Box::pin(async { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("hello from callback"))) + .unwrap()) + }) + })))); + + let (tx, rx) = oneshot::channel::<()>(); + + // Use port 0 to let the OS pick an available port. + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn({ + let callback = Arc::clone(&callback); + async move { + serve_loop(listener, callback, rx).await.unwrap(); + } + }); + + // Send a request to the server. + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "hello from callback"); + + // A second request should get a 410 GONE. + let resp2 = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp2.status(), 410); + + // Shut down the server. + tx.send(()).unwrap(); + server_handle.await.unwrap(); + } +} diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs new file mode 100644 index 00000000..1ba94478 --- /dev/null +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -0,0 +1,143 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; + +use crate::{FormatStyle, VCliContext}; + +pub trait VCliConfig { + fn host(&self) -> Option<&str>; + fn set_host(&mut self, host: String); + fn token(&self) -> Option<&str>; + fn set_token(&mut self, token: String); + fn default_format(&self) -> FormatStyle; + fn set_default_format(&mut self, format: FormatStyle); + fn mlink_redirect(&self) -> Option<&str>; + fn set_mlink_redirect(&mut self, redirect: String); + fn mlink_secret(&self) -> Option<&str>; + fn set_mlink_secret(&mut self, secret: String); + fn save(&self) -> Result<(), std::io::Error>; +} + +#[derive(Debug, Parser)] +#[clap(name = "config")] +pub struct ConfigCmd { + #[clap(subcommand)] + setting: SettingCmd, +} + +#[derive(Debug, Subcommand)] +pub enum SettingCmd { + /// Gets a setting + #[clap(subcommand, name = "get")] + Get(GetCmd), + /// Sets a setting + #[clap(subcommand, name = "set")] + Set(SetCmd), +} + +#[derive(Debug, Subcommand)] +pub enum GetCmd { + /// Get the default formatter to use when printing results + #[clap(name = "format")] + Format, + /// Get the configured API host in use + #[clap(name = "host")] + Host, + /// Get the configured access token + #[clap(name = "token")] + Token, + /// Get the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri, + /// Get the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret, +} + +#[derive(Debug, Subcommand)] +pub enum SetCmd { + /// Set the default formatter to use when printing results + #[clap(name = "format")] + Format { format: FormatStyle }, + /// Set the configured API host to use + #[clap(name = "host")] + Host { host: String }, + /// Set the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri { redirect: String }, + /// Set the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret { secret: String }, +} + +impl ConfigCmd { + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + { + match &self.setting { + SettingCmd::Get(get) => get.run(ctx.config()).await?, + SettingCmd::Set(set) => set.run(ctx.config_mut()).await?, + } + + Ok(()) + } +} + +impl GetCmd { + pub async fn run(&self, config: &T) -> Result<()> + where + T: VCliConfig, + { + match &self { + GetCmd::Format => { + println!("{}", config.default_format()); + } + GetCmd::Host => { + println!("{}", config.host().unwrap_or("None")); + } + GetCmd::Token => { + println!("{}", config.token().unwrap_or("None")); + } + GetCmd::MagicLinkRedirectUri => { + println!("{}", config.mlink_redirect().unwrap_or("None")); + } + GetCmd::MagicLinkSecret => { + println!("{}", config.mlink_secret().unwrap_or("None")); + } + } + + Ok(()) + } +} + +impl SetCmd { + pub async fn run(&self, config: &mut T) -> Result<()> + where + T: VCliConfig, + { + match &self { + SetCmd::Format { format } => { + config.set_default_format(*format); + config.save()?; + } + SetCmd::Host { host } => { + config.set_host(host.to_string()); + config.save()?; + } + SetCmd::MagicLinkRedirectUri { redirect } => { + config.set_mlink_redirect(redirect.to_string()); + config.save()?; + } + SetCmd::MagicLinkSecret { secret } => { + config.set_mlink_secret(secret.to_string()); + config.save()?; + } + } + + Ok(()) + } +} diff --git a/v-cli-sdk/src/cmd/mod.rs b/v-cli-sdk/src/cmd/mod.rs new file mode 100644 index 00000000..df67c4a3 --- /dev/null +++ b/v-cli-sdk/src/cmd/mod.rs @@ -0,0 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +pub mod auth; +pub mod config; diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs new file mode 100644 index 00000000..6f8066a6 --- /dev/null +++ b/v-cli-sdk/src/err.rs @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::{anyhow, Error}; +use progenitor_client::Error as ProgenitorClientError; + +use crate::{VApiErrorMessage, VCliContext, VerbosityLevel}; + +pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error +where + T: VCliContext, + E: VApiErrorMessage, +{ + let mut err = anyhow!("API Request failed"); + + match client_err { + ProgenitorClientError::CommunicationError(inner) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context("Communication Error").context(inner); + } + } + ProgenitorClientError::ErrorResponse(response) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Status: {}", response.status())); + err = err.context(format!("Headers {:?}", response.headers())); + } + + let response_message = response.into_inner(); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!( + "Request {}", + response_message.request_id().unwrap_or("") + )); + } + + err = err.context(format!( + "Code: {}", + response_message.error_code().unwrap_or("") + )); + err = err.context(response_message.message().unwrap_or("").to_string()); + } + ProgenitorClientError::InvalidRequest(message) => { + err = err.context("Invalid request").context(message); + } + ProgenitorClientError::InvalidResponsePayload(_, inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::UnexpectedResponse(response) => { + err = err + .context("Unexpected response") + .context(format!("Status: {}", response.status())); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Headers {:?}", response.headers())); + } + } + ProgenitorClientError::ResponseBodyError(inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::InvalidUpgrade(inner) => { + err = err.context("Invalid upgrade").context(inner) + } + ProgenitorClientError::Custom(inner) => { + err = err.context("Inner progenitor error").context(inner) + } + } + + err +} diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs new file mode 100644 index 00000000..27f678a5 --- /dev/null +++ b/v-cli-sdk/src/lib.rs @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +pub mod cmd; +pub mod err; +pub mod printer; + +use crate::cmd::auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}; +pub use cmd::config::VCliConfig; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum VerbosityLevel { + None, + All, +} + +#[derive( + Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize, Default, +)] +pub enum FormatStyle { + #[default] + #[value(name = "json")] + Json, + #[value(name = "tab")] + Tab, +} + +impl Display for FormatStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Json => write!(f, "json"), + Self::Tab => write!(f, "tab"), + } + } +} + +pub trait VCliContext { + type ShortToken; + type LongToken; + type Error; + + fn config(&self) -> &impl VCliConfig; + fn config_mut(&mut self) -> &mut impl VCliConfig; + fn client(&self) -> Option; + fn printer(&self) -> Option<&P>; + fn verbosity(&self) -> VerbosityLevel; + + fn oauth_adapter( + &self, + ) -> impl CliOAuthAdapter< + ShortToken = Self::ShortToken, + LongToken = Self::LongToken, + Error = Self::Error, + > + Send + + Sync + + 'static; + fn mlink_adapter( + &self, + ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; +} + +pub trait VApiErrorMessage { + fn message(&self) -> Option<&str>; + fn error_code(&self) -> Option<&str>; + fn request_id(&self) -> Option<&str>; +} diff --git a/v-cli-sdk/src/printer/mod.rs b/v-cli-sdk/src/printer/mod.rs new file mode 100644 index 00000000..4685b8ed --- /dev/null +++ b/v-cli-sdk/src/printer/mod.rs @@ -0,0 +1,227 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use owo_colors::{OwoColorize, Style}; +use serde::Serialize; +use std::io::Write; +use tabwriter::TabWriter; + +#[derive(Debug, Clone)] +pub enum Printer { + Json, + Tab, +} + +pub trait CliOutput { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug; +} + +impl Printer { + /// Print any serializable response object in the configured format. + /// + /// - `Json` mode emits compact, single-line JSON. + /// - `Tab` mode serializes to a `serde_json::Value` and pretty-prints it + /// with tab-aligned key/value pairs. + pub fn print_response(&self, value: &T) + where + T: Serialize, + { + let json_value = serde_json::to_value(value) + .unwrap_or_else(|e| serde_json::Value::String(format!("", e))); + + match self { + Printer::Json => { + println!("{}", serde_json::to_string(&json_value).unwrap_or_default()); + } + Printer::Tab => { + let styles = TabStyles::default(); + let mut tw = TabWriter::new(vec![]).ansi(true); + pretty_print_value(&mut tw, &json_value, 0, &styles); + tw.flush().unwrap(); + let output = String::from_utf8(tw.into_inner().unwrap()).unwrap(); + print!("{}", output); + } + } + } + + /// Print an error from a progenitor client response. + /// + /// A 401 Unauthorized is treated specially: instead of dumping the raw + /// server error we print a short, actionable message telling the user to + /// authenticate first. + pub fn print_error_response(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + // Check for 401 Unauthorized up-front, regardless of output format. + if let Some(status) = value.status() { + if status == reqwest::StatusCode::UNAUTHORIZED { + eprintln!("Authentication required. Please run `sprue auth login` first."); + return; + } + } + + match self { + Printer::Json => { + // For JSON mode, try to extract a serializable body from the + // error and fall back to the Debug representation. + let msg = match value { + progenitor_client::Error::ErrorResponse(rv) => { + serde_json::to_string(rv.as_ref()).ok() + } + _ => None, + }; + eprintln!("{}", msg.unwrap_or_else(|| format!("{:?}", value))); + } + Printer::Tab => { + eprintln!("{}", value); + } + } + } +} + +impl CliOutput for Printer { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + self.print_error_response(value); + } +} + +// --------------------------------------------------------------------------- +// Tab-indented pretty-printer for serde_json::Value +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +struct TabStyles { + label: Style, + value: Style, + null: Style, +} + +impl Default for TabStyles { + fn default() -> Self { + TabStyles { + label: Style::new().bold(), + value: Style::new(), + null: Style::new().dimmed(), + } + } +} + +fn indent(tw: &mut TabWriter>, depth: usize) { + for _ in 0..depth { + let _ = write!(tw, "\t"); + } +} + +fn pretty_print_value( + tw: &mut TabWriter>, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(map) => { + for (key, val) in map { + pretty_print_field(tw, key, val, depth, styles); + } + } + serde_json::Value::Array(arr) => { + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 1, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!(tw, "{}", format_scalar(value, styles)); + } + } +} + +fn pretty_print_field( + tw: &mut TabWriter>, + key: &str, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(_) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + pretty_print_value(tw, value, depth + 1, styles); + } + serde_json::Value::Array(arr) if arr.is_empty() => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + "[]".style(styles.null), + ); + } + serde_json::Value::Array(arr) if arr.iter().all(is_scalar) => { + // Print simple arrays inline, one value per line with the key on + // the first line only (mimics the existing TabDisplay list style). + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + if i == 0 { + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(val, styles), + ); + } else { + let _ = writeln!(tw, "\t{}", format_scalar(val, styles)); + } + } + } + serde_json::Value::Array(arr) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + for (i, val) in arr.iter().enumerate() { + indent(tw, depth + 1); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 2, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(value, styles), + ); + } + } +} + +fn is_scalar(value: &serde_json::Value) -> bool { + matches!( + value, + serde_json::Value::Null + | serde_json::Value::Bool(_) + | serde_json::Value::Number(_) + | serde_json::Value::String(_) + ) +} + +fn format_scalar(value: &serde_json::Value, styles: &TabStyles) -> String { + match value { + serde_json::Value::Null => format!("{}", "null".style(styles.null)), + serde_json::Value::Bool(b) => format!("{}", b.style(styles.value)), + serde_json::Value::Number(n) => format!("{}", n.style(styles.value)), + serde_json::Value::String(s) => format!("{}", s.style(styles.value)), + // Fallback for non-scalars that end up here + other => format!("{}", other.style(styles.value)), + } +} diff --git a/v-model/Cargo.toml b/v-model/Cargo.toml index 87d5755a..63234757 100644 --- a/v-model/Cargo.toml +++ b/v-model/Cargo.toml @@ -24,6 +24,7 @@ serde_json = { workspace = true } steno = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } +url = { workspace = true } uuid = { workspace = true, features = ["v4", "serde"] } [dev-dependencies] diff --git a/v-model/src/lib.rs b/v-model/src/lib.rs index b02c4efa..ac7b2f7e 100644 --- a/v-model/src/lib.rs +++ b/v-model/src/lib.rs @@ -15,11 +15,9 @@ use schema_ext::MagicLinkAttemptState; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{ - collections::{BTreeMap, BTreeSet}, - fmt::Display, -}; +use std::{collections::BTreeSet, fmt::Display}; use thiserror::Error; +use url::Url; pub mod db; pub mod permissions; @@ -282,26 +280,22 @@ pub struct LoginAttempt { } impl LoginAttempt { - pub fn callback_url(&self) -> String { - let mut params = BTreeMap::new(); - - if let Some(state) = &self.state { - params.insert("state", state); - } - - if let Some(error) = &self.error { - params.insert("error", error); - } else if let Some(authz_code) = &self.authz_code { - params.insert("code", authz_code); + pub fn callback_url(&self) -> Result { + let mut url = Url::parse(&self.redirect_uri)?; + + { + let mut pairs = url.query_pairs_mut(); + if let Some(state) = &self.state { + pairs.append_pair("state", state); + } + if let Some(error) = &self.error { + pairs.append_pair("error", error); + } else if let Some(authz_code) = &self.authz_code { + pairs.append_pair("code", authz_code); + } } - let query_string = params - .into_iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("&"); - - [self.redirect_uri.as_str(), query_string.as_str()].join("?") + Ok(url.to_string()) } } @@ -324,6 +318,13 @@ impl NewLoginAttempt { redirect_uri: String, scope: String, ) -> Result { + // Validate that the redirect URI is a well-formed URL. This ensures + // callback_url() can always parse it later. + Url::parse(&redirect_uri).map_err(|err| InvalidValueError { + field: "redirect_uri".to_string(), + error: format!("Invalid URL: {}", err), + })?; + Ok(Self { id: TypedUuid::new_v4(), attempt_state: LoginAttemptState::New, diff --git a/v-model/src/storage/mod.rs b/v-model/src/storage/mod.rs index 114a77ec..70568dd3 100644 --- a/v-model/src/storage/mod.rs +++ b/v-model/src/storage/mod.rs @@ -270,6 +270,15 @@ pub trait LoginAttemptStore { pagination: &ListPagination, ) -> Result, StoreError>; async fn upsert(&self, attempt: NewLoginAttempt) -> Result; + /// Atomically update a login attempt, but only if it is currently in the `expected_state`. + /// The `attempt` must have the desired target state and any other field updates already set. + /// Returns `StoreError::InvariantFailed` if the attempt is not in the expected state, + /// which prevents TOCTOU races on state transitions. + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result; } #[derive(Debug, Default)] diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index 5daaaf5c..16a38f82 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -29,7 +29,7 @@ use crate::{ magic_link_client_redirect_uri, magic_link_client_secret, mapper, oauth_client, oauth_client_redirect_uri, oauth_client_secret, }, - schema_ext::MagicLinkAttemptState, + schema_ext::{LoginAttemptState, MagicLinkAttemptState}, storage::{LinkRequestFilter, LinkRequestStore, StoreError}, AccessGroup, AccessGroupId, AccessToken, AccessTokenId, ApiKey, ApiKeyId, ApiUser, ApiUserContactEmail, ApiUserInfo, ApiUserProvider, LinkRequest, LinkRequestId, LoginAttempt, @@ -775,6 +775,38 @@ impl LoginAttemptStore for PostgresStore { Ok(attempt_m.into()) } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result { + let conn = self.pool.get().await?; + let result: Option = update( + login_attempt::dsl::login_attempt + .filter(login_attempt::id.eq(attempt.id.into_untyped_uuid())) + .filter(login_attempt::attempt_state.eq(expected_state)), + ) + .set(( + login_attempt::attempt_state.eq(attempt.attempt_state.clone()), + login_attempt::authz_code.eq(attempt.authz_code), + login_attempt::expires_at.eq(attempt.expires_at), + login_attempt::error.eq(attempt.error), + login_attempt::provider_authz_code.eq(attempt.provider_authz_code), + login_attempt::provider_error.eq(attempt.provider_error), + )) + .get_result_async::(&*conn) + .await + .optional()?; + + match result { + Some(attempt) => Ok(LoginAttempt::from(attempt)), + None => Err(StoreError::InvariantFailed(format!( + "Login attempt {} is not in expected state for transition to {}", + attempt.id, attempt.attempt_state + ))), + } + } } #[async_trait]