diff --git a/Cargo.lock b/Cargo.lock index 2081413a..ea8930a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -752,7 +752,7 @@ dependencies = [ [[package]] name = "dropshot-authorization-header" -version = "0.3.0" +version = "0.4.0" dependencies = [ "async-trait", "base64", @@ -1885,6 +1885,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "owo-colors" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" + [[package]] name = "parking_lot" version = "0.12.5" @@ -2081,6 +2087,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "progenitor-client" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8a874cf25a33cac7a01b9c1de87bcfbc8aea93f3156d09dcc3bee516a78926" +dependencies = [ + "bytes", + "futures-core", + "percent-encoding", + "reqwest", + "serde", + "serde_json", + "serde_urlencoded", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2329,6 +2350,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "serde_json", + "serde_urlencoded", "sync_wrapper", "tokio", "tokio-rustls 0.26.4", @@ -3077,6 +3099,15 @@ dependencies = [ "libc", ] +[[package]] +name = "tabwriter" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fce91f2f0ec87dff7e6bcbbeb267439aa1188703003c6055193c821487400432" +dependencies = [ + "unicode-width", +] + [[package]] name = "take_mut" version = "0.2.2" @@ -3467,6 +3498,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -3524,7 +3561,7 @@ dependencies = [ [[package]] name = "v-api" -version = "0.3.0" +version = "0.4.0" dependencies = [ "anyhow", "async-trait", @@ -3546,7 +3583,6 @@ dependencies = [ "oauth2", "oauth2-reqwest", "partial-struct", - "percent-encoding", "rand 0.10.1", "reqwest", "rsa", @@ -3573,7 +3609,7 @@ dependencies = [ [[package]] name = "v-api-installer" -version = "0.3.0" +version = "0.4.0" dependencies = [ "diesel", "diesel_migrations", @@ -3581,7 +3617,7 @@ dependencies = [ [[package]] name = "v-api-param" -version = "0.3.0" +version = "0.4.0" dependencies = [ "secrecy", "serde", @@ -3592,7 +3628,7 @@ dependencies = [ [[package]] name = "v-api-permission-derive" -version = "0.3.0" +version = "0.4.0" dependencies = [ "heck", "newtype-uuid", @@ -3607,9 +3643,32 @@ dependencies = [ "v-model", ] +[[package]] +name = "v-cli-sdk" +version = "0.4.0" +dependencies = [ + "anyhow", + "clap", + "http", + "http-body-util", + "hyper", + "hyper-util", + "oauth2", + "oauth2-reqwest", + "owo-colors", + "progenitor-client", + "reqwest", + "schemars 0.8.22", + "serde", + "serde_json", + "tabwriter", + "tokio", + "uuid", +] + [[package]] name = "v-model" -version = "0.3.0" +version = "0.4.0" dependencies = [ "async-bb8-diesel", "async-trait", @@ -3627,6 +3686,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", + "url", "uuid", "v-api-installer", ] @@ -4267,7 +4327,7 @@ checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] name = "xtask" -version = "0.3.0" +version = "0.4.0" dependencies = [ "clap", "regex", diff --git a/Cargo.toml b/Cargo.toml index c1fa7faa..d3ad8588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "v-api-installer", "v-api-param", "v-api-permission-derive", + "v-cli-sdk", "v-model", "xtask" ] @@ -12,7 +13,7 @@ resolver = "2" [workspace.package] publish = true edition = "2024" -version = "0.3.0" +version = "0.4.0" [workspace.dependencies] anyhow = "1.0" @@ -34,14 +35,16 @@ hex = "0.4.3" http = "1" http-body-util = "0.1.3" hyper = "1.9.0" +hyper-util = "0.1" jsonwebtoken = { version = "10.2", features = ["aws_lc_rs"] } mockall = "0.14.0" newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] } oauth2 = { version = "5.0.0", default-features = false } oauth2-reqwest = "0.1.0-alpha.3" +owo-colors = "4.2.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } -percent-encoding = "2.3.2" proc-macro2 = "1" +progenitor-client = "0.14.0" quote = "1" rand = "0.10.1" rand_core = "0.10.1" @@ -58,6 +61,7 @@ sha2 = "0.11.0" slog = "2.8.2" steno = { git = "https://github.com/oxidecomputer/steno" } syn = "2" +tabwriter = "1.4.1" tap = "1.0.1" tempfile = "3" thiserror = "2" diff --git a/v-api-permission-derive/src/lib.rs b/v-api-permission-derive/src/lib.rs index 3abb41dd..aa75984d 100644 --- a/v-api-permission-derive/src/lib.rs +++ b/v-api-permission-derive/src/lib.rs @@ -439,6 +439,7 @@ fn from_system_permission_tokens( VPermission::ManageMagicLinkClientsAll => Self::ManageMagicLinkClientsAll, VPermission::CreateAccessToken => Self::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken => Self::RetrieveRemoteAccessToken, VPermission::GetSagasAll => Self::GetSagasAll, VPermission::ManageSagasAll => Self::ManageSagasAll, @@ -722,6 +723,7 @@ fn system_permission_tokens() -> TokenStream { ManageMagicLinkClientsAll, CreateAccessToken, + RetrieveRemoteAccessToken, #[v_api(scope(to = "saga:r", from = "saga:r"))] GetSagasAll, diff --git a/v-api/Cargo.toml b/v-api/Cargo.toml index 35e0a586..0e036309 100644 --- a/v-api/Cargo.toml +++ b/v-api/Cargo.toml @@ -29,7 +29,6 @@ oauth2 = { workspace = true } oauth2-reqwest = { workspace = true } newtype-uuid = { workspace = true } partial-struct = { workspace = true } -percent-encoding = { workspace = true } rand = { workspace = true, features = ["std"] } reqwest = { workspace = true } rsa = { workspace = true, features = ["sha2"] } diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 4c3ee2ad..0823d05a 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -8,20 +8,23 @@ use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, Jwk, KeyAlgorithm, PublicKeyUse, RSAKeyParameters, RSAKeyType, }; +use newtype_uuid::TypedUuid; +use partial_struct::partial; use rsa::{ RsaPrivateKey, RsaPublicKey, pkcs1v15::{SigningKey, VerifyingKey}, pkcs8::{DecodePrivateKey, DecodePublicKey}, traits::PublicKeyParts, }; -use secrecy::ExposeSecret; +use secrecy::{ExposeSecret, SecretString}; use serde::{ Deserialize, Deserializer, de::{self, Visitor}, }; use std::path::PathBuf; use thiserror::Error; -use v_api_param::StringParam; +use v_api_param::{ParamResolutionError, StringParam}; +use v_model::OAuthClientId; use crate::{ authn::{ @@ -151,25 +154,108 @@ pub struct SendGridConfig { pub struct OAuthProviders { pub github: Option, pub google: Option, + pub zendesk: Option, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthConfig { - pub device: OAuthDeviceConfig, - pub web: OAuthWebConfig, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub device: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub web: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub proxy_web: Option, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthDeviceConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthDeviceConfig { - pub client_id: String, - pub client_secret: StringParam, + pub client_id: TypedUuid, + pub remote_client_id: String, + #[partial(ResolvedOAuthDeviceConfig(retype = SecretString))] + pub remote_client_secret: StringParam, } -#[derive(Debug, Deserialize)] +#[partial(ResolvedOAuthWebConfig)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthWebConfig { - pub client_id: String, - pub client_secret: StringParam, + pub remote_client_id: String, + #[partial(ResolvedOAuthWebConfig(retype = SecretString))] + pub remote_client_secret: StringParam, +} + +#[partial(ResolvedOAuthWebProxyConfig)] +#[derive(Clone, Debug, Deserialize)] +pub struct OAuthWebProxyConfig { + pub client_id: TypedUuid, pub redirect_uri: String, + pub proxy_port: u16, +} + +impl OAuthConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let device = self + .device + .as_ref() + .map(|d| d.resolve(base.clone())) + .transpose()?; + let web = self + .web + .as_ref() + .map(|w| w.resolve(base.clone())) + .transpose()?; + let proxy_web = self + .proxy_web + .as_ref() + .map(|p| p.resolve(base)) + .transpose()?; + Ok(ResolvedOAuthConfig { + device, + web, + proxy_web, + }) + } +} +impl OAuthDeviceConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let remote_client_secret = self.remote_client_secret.resolve(base)?; + Ok(ResolvedOAuthDeviceConfig { + client_id: self.client_id, + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, + }) + } +} +impl OAuthWebConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let remote_client_secret = self.remote_client_secret.resolve(base)?; + Ok(ResolvedOAuthWebConfig { + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, + }) + } +} +impl OAuthWebProxyConfig { + pub fn resolve( + &self, + _base: Option, + ) -> Result { + Ok(ResolvedOAuthWebProxyConfig { + client_id: self.client_id, + redirect_uri: self.redirect_uri.clone(), + proxy_port: self.proxy_port, + }) + } } impl AsymmetricKey { diff --git a/v-api/src/context/auth.rs b/v-api/src/context/auth.rs index 0e7958bc..e3208633 100644 --- a/v-api/src/context/auth.rs +++ b/v-api/src/context/auth.rs @@ -40,9 +40,22 @@ where jwks: JwkSet, signers: Vec, verifiers: Vec, + mut additional_permissions: Vec, ) -> Result { let signers = signers.into_iter().map(Arc::new).collect::>(); let verifiers = verifiers.into_iter().map(Arc::new).collect::>(); + additional_permissions.extend_from_slice(&[ + VPermission::CreateApiUser.into(), + VPermission::GetApiUsersAll.into(), + VPermission::ManageApiUsersAll.into(), + VPermission::GetApiKeysAll.into(), + VPermission::CreateGroup.into(), + VPermission::GetGroupsAll.into(), + VPermission::CreateMapper.into(), + VPermission::GetMappersAll.into(), + VPermission::GetOAuthClientsAll.into(), + VPermission::CreateAccessToken.into(), + ]); Ok(Self { unauthenticated_caller: Caller { id: "00000000-0000-4000-8000-000000000000".parse().unwrap(), @@ -51,19 +64,7 @@ where }, registration_caller: Caller { id: "00000000-0000-4000-8000-000000000001".parse().unwrap(), - permissions: vec![ - VPermission::CreateApiUser, - VPermission::GetApiUsersAll, - VPermission::ManageApiUsersAll, - VPermission::GetApiKeysAll, - VPermission::CreateGroup, - VPermission::GetGroupsAll, - VPermission::CreateMapper, - VPermission::GetMappersAll, - VPermission::GetOAuthClientsAll, - VPermission::CreateAccessToken, - ] - .into(), + permissions: additional_permissions.into(), extensions: HashMap::default(), }, jwt: JwtContext { @@ -215,6 +216,7 @@ mod tests { wrong_verifier.resolve_verifier(None).await.unwrap(), verifier.resolve_verifier(None).await.unwrap(), ], + vec![], ) .unwrap(); diff --git a/v-api/src/context/login.rs b/v-api/src/context/login.rs index e1264642..d2dde23a 100644 --- a/v-api/src/context/login.rs +++ b/v-api/src/context/login.rs @@ -44,14 +44,12 @@ where attempt: LoginAttempt, code: String, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.provider_authz_code = Some(code); + let mut update: NewLoginAttempt = attempt.into(); + update.provider_authz_code = Some(code); + update.attempt_state = LoginAttemptState::RemoteAuthenticated; + update.authz_code = Some(CsrfToken::new_random().secret().to_string()); - // TODO: Internal state changes to the struct - attempt.attempt_state = LoginAttemptState::RemoteAuthenticated; - attempt.authz_code = Some(CsrfToken::new_random().secret().to_string()); - - LoginAttemptStore::upsert(&*self.storage, attempt).await + LoginAttemptStore::update_if_state(&*self.storage, update, LoginAttemptState::New).await } pub async fn get_login_attempt( @@ -64,10 +62,12 @@ where pub async fn get_login_attempt_for_code( &self, code: &str, + provider: &str, ) -> Result, StoreError> { let filter = LoginAttemptFilter { attempt_state: Some(vec![LoginAttemptState::RemoteAuthenticated]), authz_code: Some(vec![code.to_string()]), + provider: Some(vec![provider.to_string()]), ..Default::default() }; @@ -84,25 +84,38 @@ where Ok(attempts.pop()) } - pub async fn complete_login_attempt( + /// Atomically claim a login attempt by transitioning it from `RemoteAuthenticated` + /// to `Complete`. Returns an error if the attempt has already been claimed by a + /// concurrent request (i.e., the state is no longer `RemoteAuthenticated`). + /// This must be called before exchanging the authorization code with the remote + /// provider to prevent the same code from being used twice (RFC 6749 §4.1.2). + pub async fn claim_login_attempt( &self, attempt: LoginAttempt, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Complete; - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Complete; + + LoginAttemptStore::update_if_state( + &*self.storage, + update, + LoginAttemptState::RemoteAuthenticated, + ) + .await } pub async fn fail_login_attempt( &self, attempt: LoginAttempt, + expected_state: LoginAttemptState, error: Option<&str>, provider_error: Option<&str>, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Failed; - attempt.error = error.map(|s| s.to_string()); - attempt.provider_error = provider_error.map(|s| s.to_string()); - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Failed; + update.error = error.map(|s| s.to_string()); + update.provider_error = provider_error.map(|s| s.to_string()); + + LoginAttemptStore::update_if_state(&*self.storage, update, expected_state).await } } diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index e95b2b20..43321639 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -764,6 +764,7 @@ pub struct VContextBuilder { keys: Option>, #[cfg(feature = "sagas")] saga: Option<(TypedUuid, Option)>, + additional_builtin_permissions: Vec, } impl Default for VContextBuilder @@ -790,6 +791,7 @@ where keys: None, #[cfg(feature = "sagas")] saga: None, + additional_builtin_permissions: Vec::new(), } } @@ -838,6 +840,11 @@ where self } + pub fn with_additional_builtin_permissions(mut self, permissions: Vec) -> Self { + self.additional_builtin_permissions = permissions; + self + } + pub async fn build(self) -> Result, VContextBuilderError> { if self.storage.is_some() && self.storage_url.is_some() { return Err(VContextBuilderError::ConfigConflict( @@ -900,7 +907,14 @@ where .into_iter() .filter_map(|key| key.ok()) .collect::>(); - let auth_ctx = AuthContext::new(jwt, jwks, signers, verifiers).map_err(|err| { + let auth_ctx = AuthContext::new( + jwt, + jwks, + signers, + verifiers, + self.additional_builtin_permissions, + ) + .map_err(|err| { tracing::error!(?err, "Auth context construction failed"); VContextError::InternalAuthContext })?; @@ -1248,8 +1262,13 @@ pub(crate) mod test_mocks { use crate::{ VContextBuilder, - config::JwtConfig, - endpoints::login::oauth::{OAuthProviderName, google::GoogleOAuthProvider}, + config::{ + JwtConfig, ResolvedOAuthConfig, ResolvedOAuthWebConfig, ResolvedOAuthWebProxyConfig, + }, + endpoints::login::oauth::{ + OAuthProviderName, remote::google::GoogleOAuthProvider, + remote::zendesk::ZendeskOAuthProvider, + }, mapper::DefaultMappingEngine, permissions::VPermission, util::tests::{MockKey, mock_key}, @@ -1261,7 +1280,7 @@ pub(crate) mod test_mocks { pub async fn mock_context(storage: Arc) -> VContext { let MockKey { signer, verifier } = mock_key("test"); let mut ctx = VContextBuilder::::new() - .with_public_url("".to_string()) + .with_public_url("https://test_public_url".to_string()) .with_storage(storage) .with_jwt_expiration(JwtConfig::default().default_expiration) .with_keys(vec![signer, verifier]) @@ -1280,10 +1299,38 @@ pub(crate) mod test_mocks { OAuthProviderName::Google, Box::new(move || { Box::new(GoogleOAuthProvider::new( - "google_device_client_id".to_string(), - "google_device_client_secret".to_string().into(), - "google_web_client_id".to_string(), - "google_web_client_secret".to_string().into(), + ResolvedOAuthConfig { + device: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "google_web_client_id".to_string(), + remote_client_secret: "google_web_client_secret".to_string().into(), + }), + proxy_web: None, + }, + "https://test_public_url".to_string(), + None, + )) + }), + ); + + ctx.auth.insert_oauth_provider( + OAuthProviderName::Zendesk, + Box::new(move || { + Box::new(ZendeskOAuthProvider::new( + ResolvedOAuthConfig { + device: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "zendesk_web_client_id".to_string(), + remote_client_secret: "zendesk_web_client_secret".to_string().into(), + }), + proxy_web: Some(ResolvedOAuthWebProxyConfig { + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test_public_url/pkce-callback".to_string(), + proxy_port: 1234, + }), + }, + "https://test_public_url".to_string(), + "subdomain".to_string(), None, )) }), @@ -1602,6 +1649,18 @@ pub(crate) mod test_mocks { .upsert(attempt) .await } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: v_model::LoginAttemptState, + ) -> Result { + self.login_attempt_store + .as_ref() + .unwrap() + .update_if_state(attempt, expected_state) + .await + } } #[async_trait] diff --git a/v-api/src/context/oauth.rs b/v-api/src/context/oauth.rs index 0dc2ee87..af7590a8 100644 --- a/v-api/src/context/oauth.rs +++ b/v-api/src/context/oauth.rs @@ -40,15 +40,10 @@ where pub async fn create_oauth_client( &self, caller: &Caller, + id: TypedUuid, ) -> ResourceResult { if caller.can(&VPermission::CreateOAuthClient.into()) { - Ok(OAuthClientStore::upsert( - &*self.storage, - NewOAuthClient { - id: TypedUuid::new_v4(), - }, - ) - .await?) + Ok(OAuthClientStore::upsert(&*self.storage, NewOAuthClient { id }).await?) } else { resource_restricted() } diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index f2dd46e4..f939ee38 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -69,15 +69,16 @@ mod macros { DeleteOAuthClientSecretPath, GetOAuthClientPath, InitialOAuthClientSecretResponse, }, - code::{ + flow::code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, - OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, + get_public_pkce_provider_op, + OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, OAuthAuthzCodeExchangeQuery, OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, }, - device_token::{ + flow::device_token::{ exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest, }, - OAuthProviderInfo, OAuthProviderNameParam, + OAuthProviderNameParam, OAuthProviderDeviceInfo, OAuthProviderAuthorizationCodePkceInfo } }, mappers::{ @@ -259,7 +260,7 @@ mod macros { // LOGIN ENDPOINTS - // AUTHZ CODE + // AUTHORIZATION CODE FLOW /// Generate the remote provider login url and redirect the user #[endpoint { @@ -295,15 +296,30 @@ mod macros { }] pub async fn authz_code_exchange( rqctx: RequestContext<$context_type>, + query: Query, path: Path, body: TypedBody, ) -> Result, HttpError> { - authz_code_exchange_op(&rqctx, path, body).await + authz_code_exchange_op(&rqctx, query, path, body).await } - // DEVICE CODE + // AUTHORIZATION CODE PKCE ONLY FLOW - /// Retrieve the metadata about an OAuth provider + /// Retrieve the metadata about an OAuth provider for public PKCE authorization code flow + #[endpoint { + method = GET, + path = "/login/oauth/{provider}/public-pkce" + }] + pub async fn get_web_pkce_provider( + rqctx: RequestContext<$context_type>, + path: Path, + ) -> Result, HttpError> { + get_public_pkce_provider_op(&rqctx, path).await + } + + // DEVICE CODE FLOW + + /// Retrieve the metadata about an OAuth provider for limited input flow #[endpoint { method = GET, path = "/login/oauth/{provider}/device" @@ -311,7 +327,7 @@ mod macros { pub async fn get_device_provider( rqctx: RequestContext<$context_type>, path: Path, - ) -> Result, HttpError> { + ) -> Result, HttpError> { get_device_provider_op(&rqctx, path).await } @@ -756,7 +772,9 @@ mod macros { $api.register(authz_code_exchange) .expect("Failed to register endpoint"); - // OAuth Device Login + // OAuth Login + $api.register(get_web_pkce_provider) + .expect("Failed to register endpoint"); $api.register(get_device_provider) .expect("Failed to register endpoint"); $api.register(exchange_device_token) diff --git a/v-api/src/endpoints/login/local/mod.rs b/v-api/src/endpoints/login/local/mod.rs index 122c0850..92f0e307 100644 --- a/v-api/src/endpoints/login/local/mod.rs +++ b/v-api/src/endpoints/login/local/mod.rs @@ -13,7 +13,7 @@ use v_model::permissions::PermissionStorage; use crate::{ authn::jwt::Claims, context::ApiContext, - endpoints::login::{ExternalUserId, UserInfo, oauth::device_token::ProxyTokenResponse}, + endpoints::login::{ExternalUserId, UserInfo, oauth::flow::device_token::ProxyTokenResponse}, permissions::{VAppPermission, VPermission}, }; @@ -38,6 +38,7 @@ where external_id: ExternalUserId::Local(body.external_id), verified_emails: vec![body.email], display_name: Some("Local Dev".to_string()), + idp_token: None, }; let (api_user, api_user_provider) = ctx diff --git a/v-api/src/endpoints/login/magic_link/client.rs b/v-api/src/endpoints/login/magic_link/client.rs index 09d3db38..a1981689 100644 --- a/v-api/src/endpoints/login/magic_link/client.rs +++ b/v-api/src/endpoints/login/magic_link/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ MagicLink, MagicLinkId, MagicLinkRedirectUri, MagicLinkRedirectUriId, MagicLinkSecret, MagicLinkSecretId, @@ -20,7 +21,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, }; #[instrument(skip(rqctx), err(Debug))] @@ -192,6 +193,18 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); + } + Ok(HttpResponseOk( ctx.magic_link .add_magic_link_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/magic_link/mod.rs b/v-api/src/endpoints/login/magic_link/mod.rs index ca868570..d36181f9 100644 --- a/v-api/src/endpoints/login/magic_link/mod.rs +++ b/v-api/src/endpoints/login/magic_link/mod.rs @@ -218,6 +218,7 @@ where external_id: ExternalUserId::MagicLink(body.recipient.clone()), verified_emails: vec![body.recipient], display_name: None, + idp_token: None, }, ) .await?; @@ -312,8 +313,9 @@ impl CheckMagicLinkClient for MagicLink { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + super::is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index 2088db31..7034d782 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -10,6 +10,7 @@ use serde::{ de::{self, Visitor}, }; use thiserror::Error; +use url::Url; use crate::{ permissions::VPermission, @@ -60,6 +61,7 @@ impl From for HttpError { pub enum ExternalUserId { GitHub(String), Google(String), + Zendesk(String), #[cfg(feature = "local-dev")] Local(String), MagicLink(String), @@ -70,6 +72,7 @@ impl ExternalUserId { match self { Self::GitHub(id) => id, Self::Google(id) => id, + Self::Zendesk(id) => id, #[cfg(feature = "local-dev")] Self::Local(id) => id, Self::MagicLink(id) => id, @@ -80,6 +83,7 @@ impl ExternalUserId { match self { Self::GitHub(_) => "github", Self::Google(_) => "google", + Self::Zendesk(_) => "zendesk", #[cfg(feature = "local-dev")] Self::Local(_) => "local", Self::MagicLink(_) => "magic-link", @@ -103,6 +107,7 @@ impl Serialize for ExternalUserId { match self { ExternalUserId::GitHub(id) => serializer.serialize_str(&format!("github-{}", id)), ExternalUserId::Google(id) => serializer.serialize_str(&format!("google-{}", id)), + ExternalUserId::Zendesk(id) => serializer.serialize_str(&format!("zendesk-{}", id)), #[cfg(feature = "local-dev")] ExternalUserId::Local(id) => serializer.serialize_str(&format!("local-{}", id)), ExternalUserId::MagicLink(id) => { @@ -142,6 +147,12 @@ impl<'de> Deserialize<'de> for ExternalUserId { } else { Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) } + } else if let Some(("", id)) = value.split_once("zendesk-") { + if !id.is_empty() { + Ok(ExternalUserId::Zendesk(id.to_string())) + } else { + Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) + } } else if let Some(("", id)) = value.split_once("local-") { #[cfg(feature = "local-dev")] { @@ -181,6 +192,7 @@ pub struct UserInfo { pub external_id: ExternalUserId, pub verified_emails: Vec, pub display_name: Option, + pub idp_token: Option, } #[derive(Debug, Error)] @@ -191,11 +203,134 @@ pub enum UserInfoError { Deserialize(#[from] serde_json::Error), #[error("Failed to create user info request {0}")] Http(#[from] http::Error), + #[error("User account is locked")] + Locked, #[error("User information is missing")] MissingUserInfoData(String), + #[error("User info endpoint returned HTTP {status} for {endpoint}")] + UnexpectedStatus { + endpoint: String, + status: http::StatusCode, + }, } #[async_trait] pub trait UserInfoProvider { async fn get_user_info(&self, token: &str) -> Result; } + +/// Structurally compare a candidate redirect URI against a list of registered redirect URIs. +/// Comparison is performed on scheme, host, port, and path. URIs that fail to parse or contain +/// fragments are rejected (per RFC 6749 §3.1.2). +pub fn is_redirect_uri_valid<'a>( + redirect_uri: &str, + registered_uris: impl Iterator, +) -> bool { + let candidate = match Url::parse(redirect_uri) { + Ok(url) => url, + Err(_) => return false, + }; + + // Reject redirect URIs that contain fragments (per RFC 6749 §3.1.2) + if candidate.fragment().is_some() { + return false; + } + + registered_uris + .into_iter() + .any(|registered| match Url::parse(registered) { + Ok(registered) => { + registered.scheme() == candidate.scheme() + && registered.host() == candidate.host() + && registered.port() == candidate.port() + && registered.path() == candidate.path() + && registered.query() == candidate.query() + } + Err(_) => false, + }) +} + +#[cfg(test)] +mod tests { + use super::is_redirect_uri_valid; + + #[test] + fn test_redirect_uri_exact_match() { + assert!(is_redirect_uri_valid( + "https://example.com/callback", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_different_host() { + assert!(!is_redirect_uri_valid( + "https://evil.com/callback", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_different_path() { + assert!(!is_redirect_uri_valid( + "https://example.com/other", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_fragment() { + assert!(!is_redirect_uri_valid( + "https://example.com/callback#fragment", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_unparseable() { + assert!(!is_redirect_uri_valid( + "not-a-url", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_query_params_must_match() { + // Registered with query params — candidate must have the same query + assert!(is_redirect_uri_valid( + "https://example.com/callback?key=value", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Different query param value must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback?key=evil", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Missing query params when registered URI has them must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Extra query params when registered URI has none must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback?extra=param", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_matches_with_port() { + assert!(is_redirect_uri_valid( + "https://example.com:8443/callback", + ["https://example.com:8443/callback"].iter().copied(), + )); + + assert!(!is_redirect_uri_valid( + "https://example.com:9999/callback", + ["https://example.com:8443/callback"].iter().copied(), + )); + } +} diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index 59562ccc..22e0e0df 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ OAuthClient, OAuthClientId, OAuthClientRedirectUri, OAuthClientSecret, OAuthRedirectUriId, OAuthSecretId, @@ -20,7 +21,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, }; #[instrument(skip(rqctx), err(Debug))] @@ -54,7 +55,10 @@ where T: VAppPermission + From + PermissionStorage, { // Create the new client - let client = ctx.oauth.create_oauth_client(&caller).await?; + let client = ctx + .oauth + .create_oauth_client(&caller, TypedUuid::new_v4()) + .await?; // Give the caller permission to perform actions on the client ctx.user @@ -188,6 +192,18 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); + } + Ok(HttpResponseOk( ctx.oauth .add_oauth_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs deleted file mode 100644 index 6ba1d0d1..00000000 --- a/v-api/src/endpoints/login/oauth/code.rs +++ /dev/null @@ -1,1637 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD}; -use chrono::{TimeDelta, Utc}; -use cookie::{Cookie, SameSite}; -use dropshot::{ - ClientErrorStatusCode, HttpError, HttpResponseOk, HttpResponseTemporaryRedirect, Path, Query, - RequestContext, RequestInfo, SharedExtractor, TypedBody, http_response_temporary_redirect, -}; -use dropshot_authorization_header::basic::BasicAuth; -use http::{HeaderValue, header::SET_COOKIE}; -use newtype_uuid::TypedUuid; -use oauth2::{ - AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, -}; -use percent_encoding::{NON_ALPHANUMERIC, percent_encode}; -use schemars::JsonSchema; -use secrecy::SecretString; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; -use std::{fmt::Debug, ops::Add}; -use tap::TapFallible; -use tracing::instrument; -use v_model::{ - LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, - permissions::{AsScope, PermissionStorage}, - schema_ext::LoginAttemptState, -}; - -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig}; -use crate::{ - authn::key::RawKey, - context::{ApiContext, VContext}, - endpoints::login::{ - LoginError, UserInfo, - oauth::{CheckOAuthClient, ClientType}, - }, - error::ApiError, - permissions::{VAppPermission, VPermission}, - secrets::OpenApiSecretString, - util::{ - request::RequestCookies, - response::{ResourceError, internal_error, to_internal_error, unauthorized}, - }, -}; - -static LOGIN_ATTEMPT_COOKIE: &str = "__v_login"; -static DEFAULT_SCOPE: &str = "user:info:r"; - -#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -struct OAuthError { - error: OAuthErrorCode, - #[serde(skip_serializing_if = "Option::is_none")] - error_description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error_uri: Option, - #[serde(skip_serializing_if = "Option::is_none")] - state: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -#[serde(untagged)] -enum OAuthErrorCode { - AccessDenied, - InvalidClient, - InvalidGrant, - InvalidRequest, - InvalidScope, - ServerError, - TemporarilyUnavailable, - UnauthorizedClient, - UnsupportedGrantType, - UnsupportedResponseType, -} - -impl From for HttpError { - fn from(value: OAuthError) -> Self { - let serialized = serde_json::to_string(&value).unwrap(); - HttpError { - headers: None, - status_code: ClientErrorStatusCode::BAD_REQUEST.into(), - error_code: None, - external_message: serialized.clone(), - internal_message: serialized, - } - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeQuery { - pub client_id: TypedUuid, - pub redirect_uri: String, - pub response_type: String, - pub state: String, - pub scope: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeRedirectHeaders { - #[serde(rename = "set-cookies")] - cookies: String, - location: String, -} - -// Lookup the client specified by the provided client id and verify that the redirect uri -// is a valid for this client. If either of these fail we return an unauthorized response -async fn get_oauth_client( - ctx: &VContext, - client_id: &TypedUuid, - redirect_uri: &str, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let client = ctx - .oauth - .get_oauth_client(&ctx.builtin_registration_user(), client_id) - .await - .map_err(|err| { - tracing::error!(?err, "Failed to lookup OAuth client"); - - match err { - ResourceError::DoesNotExist => OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Unknown client id".to_string()), - error_uri: None, - state: None, - }, - // Given that the builtin caller should have access to all OAuth clients, any other - // error is considered an internal error - _ => OAuthError { - error: OAuthErrorCode::ServerError, - error_description: None, - error_uri: None, - state: None, - }, - } - })?; - - if client.is_redirect_uri_valid(redirect_uri) { - Ok(client) - } else { - Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }) - } -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_redirect_op( - rqctx: &RequestContext>, - path: Path, - query: Query, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let query = query.into_inner(); - - get_oauth_client(ctx, &query.client_id, &query.redirect_uri).await?; - - tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); - - // Find the configured provider for the requested remote backend. We should always have a valid - // provider value, so if this fails then a 500 is returned - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code login"); - - // Check that the passed in scopes are valid. The scopes are not currently restricted by client - let scope = query.scope.unwrap_or_else(|| DEFAULT_SCOPE.to_string()); - let scope_error = VPermission::from_scope_arg(&scope) - .err() - .map(|_| "invalid_scope".to_string()); - - // Construct a new login attempt with the minimum required values - let mut attempt = NewLoginAttempt::new( - provider.name().to_string(), - query.client_id, - query.redirect_uri, - scope, - ) - .map_err(|err| { - tracing::error!(?err, "Attempted to construct invalid login attempt"); - internal_error("Attempted to construct invalid login attempt".to_string()) - })?; - - // Set a default expiration for the login attempt - // TODO: Make this configurable - attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); - - // Assign any scope errors that arose - attempt.error = scope_error; - - // Add in the user defined state and redirect uri. State is an arbitrary value and may be - // malicious. It must be url-encoded before being presented back to the client. Therefore we - // process once before storing so all downstream consumers see the encoded value. - attempt.state = Some(percent_encode(query.state.as_bytes(), NON_ALPHANUMERIC).to_string()); - - // If the remote provider supports pkce, set up a challenge - let pkce_challenge = if provider.supports_pkce() { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - attempt.provider_pkce_verifier = Some(pkce_verifier.secret().to_string()); - Some(pkce_challenge) - } else { - None - }; - - // Store the generated attempt - let attempt = ctx - .login - .create_login_attempt(attempt) - .await - .map_err(to_internal_error)?; - - tracing::info!(?attempt.id, "Created login attempt"); - - oauth_redirect_response(ctx.public_url(), &*provider, &attempt, pkce_challenge) -} - -fn oauth_redirect_response( - public_url: &str, - provider: &dyn OAuthProvider, - attempt: &LoginAttempt, - code_challenge: Option, -) -> Result { - // We may fail if the provider configuration is not correctly configured - // TODO: This behavior should be changed so that clients are precomputed. We do not need to be - // constructing a new client on every request. That said, we need to ensure the client does not - // maintain state between requests - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; - - // Create an attempt cookie header for storing the login attempt. This also acts as our csrf - // check - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, attempt.id.to_string()); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(public_url.starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(600)); - - let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; - - // Generate the url to the remote provider that the user will be redirected to - let mut authz_url = client - .authorize_url(|| CsrfToken::new(attempt.id.to_string())) - .add_scopes( - provider - .scopes() - .into_iter() - .map(|s| Scope::new(s.to_string())) - .collect::>(), - ); - - // If the caller has provided a code challenge, add it to the url - if let Some(challenge) = code_challenge { - authz_url = authz_url.set_pkce_challenge(challenge); - }; - - let mut redirect = http_response_temporary_redirect(authz_url.url().0.to_string())?; - redirect.headers_mut().append(SET_COOKIE, login_cookie); - - Ok(redirect) -} - -// TODO: Determine if 401 empty responses are correct here -fn verify_csrf( - request: &RequestInfo, - query: &OAuthAuthzCodeReturnQuery, -) -> Result, HttpError> { - // If we are missing the expected state parameter then we can not proceed at all with verifying - // this callback request. We also do not have a redirect uri to send the user to so we instead - // report unauthorized - let attempt_id = query - .state - .as_ref() - .ok_or_else(|| { - tracing::warn!("OAuth callback is missing a state parameter"); - unauthorized() - })? - .parse() - .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); - unauthorized() - })?; - - // The client must present the attempt cookie at a minimum. Without it we are unable to lookup a - // login attempt to match against. Without the cookie to verify the state parameter we can not - // determine a redirect uri so we instead report unauthorized - let attempt_cookie = request - .cookie(LOGIN_ATTEMPT_COOKIE) - .ok_or_else(|| { - tracing::warn!("OAuth callback is missing a login state cookie"); - unauthorized() - })? - .value() - .parse() - .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); - unauthorized() - })?; - - // Verify that the attempt_id returned from the state matches the expected client value. If they - // do not match we can not lookup a redirect uri so we instead return unauthorized - if attempt_id != attempt_cookie { - tracing::warn!( - ?attempt_id, - ?attempt_cookie, - "OAuth state does not match expected cookie value" - ); - Err(unauthorized()) - } else { - Ok(attempt_id) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeReturnQuery { - pub state: Option, - pub code: Option, - pub error: Option, -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_callback_op( - rqctx: &RequestContext>, - path: Path, - query: Query, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let query = query.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); - - // Verify and extract the attempt id before performing any work - let attempt_id = verify_csrf(&rqctx.request, &query)?; - - // Clear the login attempt cookie - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, ""); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(ctx.public_url().starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(0)); - let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; - - let mut redirect = http_response_temporary_redirect( - authz_code_callback_op_inner(ctx, &attempt_id, query.code, query.error).await?, - )?; - redirect.headers_mut().append(SET_COOKIE, login_cookie); - - Ok(redirect) -} - -pub async fn authz_code_callback_op_inner( - ctx: &VContext, - attempt_id: &TypedUuid, - code: Option, - error: Option, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - // We have now verified the attempt id and can use it to look up the rest of the login attempt - // material to try and complete the flow - let mut attempt = ctx - .login - .get_login_attempt(attempt_id) - .await - .map_err(to_internal_error)? - .ok_or_else(|| { - // If we fail to find a matching attempt, there is not much we can do other than return - // unauthorized - unauthorized() - }) - .and_then(|attempt| { - if attempt.attempt_state == LoginAttemptState::New { - Ok(attempt) - } else { - Err(unauthorized()) - } - })?; - - attempt = match (code, error) { - (Some(code), None) => { - tracing::info!(?attempt.id, "Received valid login attempt. Storing authorization code"); - - // Store the authorization code returned by the underlying OAuth provider and transition the - // attempt to the awaiting state - ctx.login - .set_login_provider_authz_code(attempt, code.to_string()) - .await - .map_err(to_internal_error)? - } - (code, error) => { - tracing::info!(?attempt.id, ?error, "Received an error response from the remote server"); - - // Store the provider return error for future debugging, but if an error has been - // returned or there is a missing code, then we can not report a successful process - attempt.provider_authz_code = code; - - // When a user has explicitly denied access we want to forward that error message - // onwards to the upstream requester. All other errors should be opaque to the - // original requester and are returned as server errors - let error_message = match error.as_deref() { - Some("access_denied") => "access_denied", - _ => "server_error", - }; - - // TODO: Specialize the returned error - ctx.login - .fail_login_attempt(attempt, Some(error_message), error.as_deref()) - .await - .map_err(to_internal_error)? - } - }; - - // Redirect back to the original authenticator - Ok(attempt.callback_url()) -} - -#[derive(Debug, Deserialize, JsonSchema)] -pub struct OAuthAuthzCodeExchangeBody { - pub client_id: Option>, - pub client_secret: Option, - pub redirect_uri: String, - pub grant_type: String, - pub code: String, - pub pkce_verifier: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeExchangeResponse { - pub access_token: String, - pub token_type: String, - pub expires_in: i64, -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_exchange_op( - rqctx: &RequestContext>, - path: Path, - body: TypedBody, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let body = body.into_inner(); - - let (client_id, client_secret) = - if let (Some(client_id), Some(client_secret)) = (body.client_id, body.client_secret) { - Ok::<_, HttpError>((client_id, client_secret)) - } else { - // Attempt to extract basic authorization credentials from the request if they were not - // present in the request body - let auth = ::from_request(rqctx) - .await - .tap_err(|err| { - tracing::warn!(?err, "Failed to extract basic authentication values"); - }); - let (client_id, client_secret) = match auth { - Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(( - auth.username().unwrap().to_string(), - auth.password().unwrap().to_string(), - )), - _ => Err(internal_error( - "Missing client id and client secret from authz code exchange", - )), - }?; - - Ok(( - client_id.parse().map_err(to_internal_error)?, - OpenApiSecretString(client_secret.into()), - )) - }?; - - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!("Attempting code exchange"); - - // Verify the submitted client credentials - authorize_code_exchange( - ctx, - &body.grant_type, - client_id, - &client_secret.0, - &body.redirect_uri, - ) - .await?; - - tracing::debug!("Authorized code exchange"); - - // Lookup the request assigned to this code - let mut attempt = ctx - .login - .get_login_attempt_for_code(&body.code) - .await - .map_err(to_internal_error)? - .ok_or(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: None, - error_uri: None, - state: None, - })?; - - // Verify that the login attempt is valid and matches the submitted client credentials - verify_login_attempt( - &attempt, - client_id, - &body.redirect_uri, - body.pkce_verifier.as_deref(), - )?; - - tracing::debug!("Verified login attempt"); - - // Now that the attempt has been confirmed, use it to fetch user information form the remote - // provider - let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?; - - tracing::debug!("Retrieved user information from remote provider"); - - // During fetch_user_info we revoke any downstream codes if possible, therefore At this point we - // consider the login attempt to be consumed and can no longer be used. We state transition to - // complete, even though we may fail further along in the handler. If a failure occurs then the - // user will need to re-authenticate. - attempt = ctx - .login - .complete_login_attempt(attempt) - .await - .map_err(|err| { - tracing::error!(?err, "Failed to complete login attempt"); - OAuthError { - error: OAuthErrorCode::ServerError, - error_description: Some("An unexpected error occurred".to_string()), - error_uri: None, - state: None, - } - })?; - - // Register this user as an API user if needed - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); - - let scope = attempt - .scope - .split(' ') - .map(|s| s.to_string()) - .collect::>(); - - let token = ctx - .generate_access_token( - &ctx.builtin_registration_user(), - &api_user_info.user.id, - &api_user_provider.id, - Some(scope), - ) - .await?; - - Ok(HttpResponseOk(OAuthAuthzCodeExchangeResponse { - token_type: "Bearer".to_string(), - access_token: token.signed_token, - expires_in: token.expires_in, - })) -} - -async fn authorize_code_exchange( - ctx: &VContext, - grant_type: &str, - client_id: TypedUuid, - client_secret: &SecretString, - redirect_uri: &str, -) -> Result<(), OAuthError> -where - T: VAppPermission + PermissionStorage, -{ - let client = get_oauth_client(ctx, &client_id, redirect_uri).await?; - - // Verify that we received the expected grant type - if grant_type != "authorization_code" { - return Err(OAuthError { - error: OAuthErrorCode::UnsupportedGrantType, - error_description: None, - error_uri: None, - state: None, - }); - } - - tracing::debug!(grant_type, "Verified grant type"); - - let client_secret = RawKey::try_from(client_secret).map_err(|err| { - tracing::warn!(?err, "Failed to parse OAuth client secret"); - - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Malformed client secret".to_string()), - error_uri: None, - state: None, - } - })?; - - tracing::debug!("Constructed client secret"); - - if !client.is_secret_valid(&client_secret, ctx) { - Err(OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Invalid client secret".to_string()), - error_uri: None, - state: None, - }) - } else { - tracing::debug!("Verified client secret validity"); - - Ok(()) - } -} - -fn verify_login_attempt( - attempt: &LoginAttempt, - client_id: TypedUuid, - redirect_uri: &str, - pkce_verifier: Option<&str>, -) -> Result<(), OAuthError> { - if attempt.client_id != client_id { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid client id".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.redirect_uri != redirect_uri { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.attempt_state != LoginAttemptState::RemoteAuthenticated { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.expires_at.map(|t| t <= Utc::now()).unwrap_or(true) { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant has expired".to_string()), - error_uri: None, - state: None, - }) - } else { - match (attempt.pkce_challenge.as_deref(), pkce_verifier) { - (Some(_), None) => Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }), - (Some(challenge), Some(verifier)) => { - let mut hasher = Sha256::new(); - hasher.update(verifier); - let hash = hasher.finalize(); - let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash); - - if challenge == computed_challenge { - Ok(()) - } else { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid pkce verifier".to_string()), - error_uri: None, - state: None, - }) - } - } - (None, _) => Ok(()), - } - } -} - -#[instrument(skip(attempt))] -async fn fetch_user_info( - public_url: &str, - client_type: &ClientType, - provider: &dyn OAuthProvider, - attempt: &LoginAttempt, -) -> Result { - // Exchange the stored authorization code with the remote provider for a remote access token - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; - - let mut request = client.exchange_code(AuthorizationCode::new( - attempt - .provider_authz_code - .as_ref() - .ok_or_else(|| { - internal_error("Expected authorization code to exist due to attempt state") - })? - .to_string(), - )); - - if let Some(pkce_verifier) = &attempt.provider_pkce_verifier { - request = request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string())) - } - - let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); - let response = request - .request_async(&oauth_client) - .await - .map_err(to_internal_error)?; - - tracing::info!("Fetched access token from remote service"); - - // Use the retrieved access token to fetch the user information from the remote API - let info = provider - .get_user_info(response.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; - - tracing::info!("Fetched user info from remote service"); - - // Now that we are done with fetching user information from the remote API, we can revoke it if - // the provider supports it - if provider.token_revocation_endpoint().is_some() { - client - .revoke_token(response.access_token().into()) - .map_err(internal_error)? - .request_async(&oauth_client) - .await - .map_err(internal_error)?; - } - - Ok(info) -} - -#[cfg(test)] -mod tests { - use std::{ - net::{Ipv4Addr, SocketAddrV4}, - ops::Add, - sync::{Arc, Mutex}, - }; - - use chrono::{TimeDelta, Utc}; - use dropshot::{HttpResponse, RequestInfo}; - use http::{ - HeaderValue, StatusCode, - header::{COOKIE, LOCATION, SET_COOKIE}, - }; - use http_body_util::Empty; - use mockall::predicate::eq; - use newtype_uuid::TypedUuid; - use oauth2::PkceCodeChallenge; - use secrecy::SecretString; - use uuid::Uuid; - use v_model::{ - LoginAttempt, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, - schema_ext::LoginAttemptState, - storage::{MockLoginAttemptStore, MockOAuthClientStore}, - }; - - use crate::{ - authn::key::RawKey, - context::{ - VContext, - test_mocks::{MockStorage, mock_context}, - }, - endpoints::login::oauth::{ - OAuthProviderName, - code::{ - LOGIN_ATTEMPT_COOKIE, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, - authz_code_callback_op_inner, verify_csrf, verify_login_attempt, - }, - }, - permissions::VPermission, - }; - - use super::{authorize_code_exchange, get_oauth_client, oauth_redirect_response}; - - async fn mock_client() -> (VContext, OAuthClient, SecretString) { - let ctx = mock_context(Arc::new(MockStorage::new())).await; - let client_id = TypedUuid::new_v4(); - let key = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(ctx.signer()) - .await - .unwrap(); - let secret_signature = key.signature().to_string(); - let client_secret = key.key(); - let redirect_uri = "callback-destination"; - - ( - ctx, - OAuthClient { - id: client_id, - secrets: vec![OAuthClientSecret { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - secret_signature, - created_at: Utc::now(), - deleted_at: None, - }], - redirect_uris: vec![OAuthClientRedirectUri { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - redirect_uri: redirect_uri.to_string(), - created_at: Utc::now(), - deleted_at: None, - }], - created_at: Utc::now(), - deleted_at: None, - }, - client_secret, - ) - } - - #[tokio::test] - async fn test_oauth_client_lookup_checks_redirect_uri() { - let client_id = TypedUuid::new_v4(); - let client = OAuthClient { - id: client_id, - secrets: vec![], - redirect_uris: vec![OAuthClientRedirectUri { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - created_at: Utc::now(), - deleted_at: None, - }], - created_at: Utc::now(), - deleted_at: None, - }; - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let failure = get_oauth_client(&ctx, &client_id, "https://not-test.oxeng.dev/callback") - .await - .unwrap_err(); - assert_eq!(OAuthErrorCode::InvalidRequest, failure.error); - assert_eq!( - Some("Invalid redirect uri".to_string()), - failure.error_description - ); - - let success = get_oauth_client(&ctx, &client_id, "https://test.oxeng.dev/callback").await; - assert_eq!(client_id, success.unwrap().id); - } - - #[tokio::test] - async fn test_remote_provider_redirect_url() { - let storage = MockStorage::new(); - let mut ctx = mock_context(Arc::new(storage)).await; - ctx.with_public_url("https://api.oxeng.dev"); - - let (challenge, _) = PkceCodeChallenge::new_random_sha256(); - let attempt = LoginAttempt { - id: TypedUuid::new_v4(), - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let response = oauth_redirect_response( - ctx.public_url(), - &*ctx - .get_oauth_provider(&OAuthProviderName::Google) - .await - .unwrap(), - &attempt, - Some(challenge.clone()), - ) - .unwrap() - .to_result() - .unwrap(); - let headers = response.headers(); - - let expected_location = format!( - "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", - attempt.id, - challenge.as_str() - ); - - assert_eq!( - expected_location, - String::from_utf8(headers.get(LOCATION).unwrap().as_bytes().to_vec()).unwrap() - ); - assert_eq!( - format!( - "{}; HttpOnly; SameSite=Lax; Secure; Max-Age=600", - attempt.id - ) - .as_str(), - String::from_utf8(headers.get(SET_COOKIE).unwrap().as_bytes().to_vec()) - .unwrap() - .split_once('=') - .unwrap() - .1 - ) - } - - #[tokio::test] - async fn test_csrf_check() { - let id = TypedUuid::new_v4(); - - let mut rq = hyper::Request::new(Empty::<()>::new()); - rq.headers_mut().insert( - COOKIE, - HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, id)).unwrap(), - ); - let with_valid_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!(id, verify_csrf(&with_valid_cookie, &query).unwrap()); - - let query = OAuthAuthzCodeReturnQuery { - state: None, - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_valid_cookie, &query) - .unwrap_err() - .status_code - ); - - let mut rq = hyper::Request::new(Empty::<()>::new()); - rq.headers_mut().insert( - COOKIE, - HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, Uuid::new_v4())).unwrap(), - ); - let with_invalid_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_invalid_cookie, &query) - .unwrap_err() - .status_code - ); - - let rq = hyper::Request::new(Empty::<()>::new()); - let with_missing_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_missing_cookie, &query) - .unwrap_err() - .status_code - ); - } - - #[tokio::test] - async fn test_callback_fails_when_not_in_new_state() { - let invalid_states = [ - LoginAttemptState::Complete, - LoginAttemptState::Failed, - LoginAttemptState::RemoteAuthenticated, - ]; - - for state in invalid_states { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: state, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut storage = MockStorage::new(); - let mut attempt_store = MockLoginAttemptStore::new(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(attempt.clone()))); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - - let ctx = mock_context(Arc::new(storage)).await; - let err = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - None, - ) - .await; - - assert_eq!(StatusCode::UNAUTHORIZED, err.unwrap_err().status_code); - } - } - - #[tokio::test] - async fn test_callback_fails_when_error_is_passed() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - returned.error = arg.error; - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - Some("not_access_denied".to_string()), - ) - .await - .unwrap(); - - assert_eq!( - format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",), - location - ); - } - - #[tokio::test] - async fn test_callback_forwards_access_denied() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - returned.error = arg.error; - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - Some("access_denied".to_string()), - ) - .await - .unwrap(); - - assert_eq!( - format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",), - location - ); - } - - #[tokio::test] - async fn test_handles_callback_with_code() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - let extracted_code = Arc::new(Mutex::new(None)); - let extractor = extracted_code.clone(); - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::RemoteAuthenticated) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - *extractor.lock().unwrap() = returned.authz_code.clone(); - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = - authz_code_callback_op_inner(&ctx, &attempt_id, Some("remote-code".to_string()), None) - .await - .unwrap(); - - let lock = extracted_code.lock(); - assert_eq!( - format!( - "https://test.oxeng.dev/callback?code={}&state=ox_state", - lock.unwrap().as_ref().unwrap() - ), - location - ); - } - - #[tokio::test] - async fn test_fails_callback_with_error() {} - - #[tokio::test] - async fn test_exchange_checks_client_id_and_redirect() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - let wrong_client_id = TypedUuid::new_v4(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(wrong_client_id), eq(false)) - .returning(move |_, _| Ok(None)); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - // 1. Verify exchange fails when passing an incorrect client id - assert_eq!( - Some("Unknown client id".to_string()), - authorize_code_exchange( - &ctx, - "authorization_code", - wrong_client_id, - &client_secret, - &redirect_uri, - ) - .await - .unwrap_err() - .error_description - ); - - // 2. Verify exchange fails when passing an incorrect redirect uri - assert_eq!( - Some("Invalid redirect uri".to_string()), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - "wrong-callback-destination", - ) - .await - .unwrap_err() - .error_description - ); - - // 3. Verify a successful exchange - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri, - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_exchange_checks_grant_type() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - assert_eq!( - OAuthErrorCode::UnsupportedGrantType, - authorize_code_exchange( - &ctx, - "not_authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_exchange_checks_for_valid_secret() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(ctx.signer()) - .await - .unwrap() - .signature() - .to_string(); - - assert_eq!( - OAuthErrorCode::InvalidRequest, - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &"too-short".to_string().into(), - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - OAuthErrorCode::InvalidClient, - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &invalid_secret.into(), - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_login_attempt_verification() { - let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); - let attempt = LoginAttempt { - id: TypedUuid::new_v4(), - attempt_state: LoginAttemptState::RemoteAuthenticated, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some(challenge.as_str().to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let bad_client_id = LoginAttempt { - client_id: TypedUuid::new_v4(), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid client id".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &bad_client_id, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let bad_redirect_uri = LoginAttempt { - redirect_uri: "https://bad.oxeng.dev/callback".to_string(), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &bad_redirect_uri, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let unconfirmed_state = LoginAttempt { - attempt_state: LoginAttemptState::New, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &unconfirmed_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let already_used_state = LoginAttempt { - attempt_state: LoginAttemptState::Complete, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &already_used_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let failed_state = LoginAttempt { - attempt_state: LoginAttemptState::Failed, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &failed_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let expired = LoginAttempt { - expires_at: Some(Utc::now()), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant has expired".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &expired, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let missing_pkce = LoginAttempt { ..attempt.clone() }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &missing_pkce, - attempt.client_id, - &attempt.redirect_uri, - None, - ) - .unwrap_err() - ); - - let invalid_pkce = LoginAttempt { - pkce_challenge: Some("no-the-correct-value".to_string()), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid pkce verifier".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &invalid_pkce, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - assert_eq!( - (), - verify_login_attempt( - &attempt, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap() - ); - } -} diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs deleted file mode 100644 index ef0c4e72..00000000 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ /dev/null @@ -1,279 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use chrono::{DateTime, Utc}; -use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{HeaderValue, Response, StatusCode, header}; -use oauth2::{EmptyExtraTokenFields, StandardTokenResponse, TokenResponse, basic::BasicTokenType}; -use schemars::JsonSchema; -use secrecy::ExposeSecret; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; -use tap::TapFallible; -use tracing::instrument; -use v_model::permissions::PermissionStorage; - -use super::{ - ClientType, OAuthProvider, OAuthProviderInfo, OAuthProviderNameParam, UserInfoProvider, -}; -use crate::{ - context::ApiContext, endpoints::login::LoginError, error::ApiError, - permissions::VAppPermission, response::internal_error, util::response::bad_request, -}; - -#[instrument(skip(rqctx), err(Debug))] -pub async fn get_device_provider_op( - rqctx: &RequestContext>, - path: Path, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let path = path.into_inner(); - - tracing::trace!("Getting OAuth data for {}", path.provider); - - let provider = rqctx - .v_ctx() - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - Ok(HttpResponseOk(provider.provider_info( - rqctx.v_ctx().public_url(), - &ClientType::Device, - ))) -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct AccessTokenExchangeRequest { - pub device_code: String, - pub grant_type: String, - pub expires_at: Option>, -} - -#[derive(Serialize)] -pub struct AccessTokenExchange { - provider: ProviderTokenExchange, - expires_at: Option>, -} - -#[derive(Serialize)] -pub struct ProviderTokenExchange { - client_id: String, - device_code: String, - grant_type: String, - client_secret: String, -} - -impl AccessTokenExchange { - pub fn new( - req: AccessTokenExchangeRequest, - provider: &(dyn OAuthProvider + Send + Sync), - ) -> Option { - provider - .client_secret(&ClientType::Device) - .map(|client_secret| Self { - provider: ProviderTokenExchange { - client_id: provider.client_id(&ClientType::Device).to_string(), - device_code: req.device_code, - grant_type: req.grant_type, - client_secret: client_secret.expose_secret().to_string(), - }, - expires_at: req.expires_at, - }) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenResponse { - pub access_token: String, - pub token_type: String, - pub expires_in: Option, - pub refresh_token: Option, - pub scopes: Option>, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenError { - error: String, - error_description: Option, - error_uri: Option, -} - -// Complete a device exchange request against the specified provider. This effectively proxies the -// requests that would go to the provider, captures the returned access tokens, and registers a -// new internal user as needed. The user is then returned an token that is valid for interacting -// with the API -#[instrument(skip(rqctx, body), err(Debug))] -pub async fn exchange_device_token_op( - rqctx: &RequestContext>, - path: Path, - body: TypedBody, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); - - let exchange_request = body.into_inner(); - - if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { - let token_exchange_endpoint = provider.token_exchange_endpoint(); - let client = reqwest::Client::new(); - - let response = client - .request(Method::POST, token_exchange_endpoint) - .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) - .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .body( - // We know that this is safe to unwrap as we just deserialized it via the body Extractor - serde_urlencoded::to_string(&exchange.provider).unwrap(), - ) - .send() - .await - .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) - .map_err(internal_error)?; - - // Take a part the response as we will need the individual parts later - let status = response.status(); - let headers = response.headers().clone(); - let bytes = response.bytes().await.map_err(internal_error)?; - - // We unfortunately can not trust our providers to follow specs and therefore need to do - // our own inspection of the response to determine what to do - if !status.is_success() { - // If the server returned a non-success status then we are going to trust the server and - // report their error back to the client - tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - Ok(client_response) - } else { - // The server gave us back a non-error response but it still may not be a success. - // GitHub for instance does not use a status code for indicating the success or failure - // of a call. So instead we try to deserialize the body into an access token, with the - // understanding that it may fail and we will need to try and treat the response as - // an error instead. - - let parsed: Result< - StandardTokenResponse, - serde_json::Error, - > = serde_json::from_slice(&bytes); - - match parsed { - Ok(parsed) => { - let info = provider - .get_user_info(parsed.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| { - tracing::error!(?err, "Failed to look up user information") - })?; - - tracing::debug!("Verified and validated OAuth user"); - - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); - - let claims = - ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); - let token = ctx - .user - .register_access_token( - &ctx.builtin_registration_user(), - ctx.jwt_signer(), - &api_user_info.user.id, - &claims, - ) - .await?; - - tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/json") - .body( - serde_json::to_string(&ProxyTokenResponse { - access_token: token.signed_token, - token_type: "Bearer".to_string(), - expires_in: Some(claims.exp - Utc::now().timestamp()), - refresh_token: None, - scopes: None, - }) - .unwrap() - .into(), - )?) - } - Err(_) => { - // Do not log the error here as we want to ensure we do not leak token information - tracing::debug!( - "Failed to parse a success response from the remote token endpoint" - ); - - // Try to deserialize the body again, but this time as an error - let mut error_response = match serde_json::from_slice::(&bytes) - { - Ok(error) => { - // We found an error in the message body. This is not ideal, but we at - // least can understand what the server was trying to tell us - tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - client_response - } - Err(_) => { - // We still do not know what the remote server is doing... and need to - // cancel the request ourselves - tracing::warn!( - "Remote OAuth provide returned a response that we do not undestand" - ); - - Response::new( - serde_json::to_vec(&ProxyTokenError { - error: "access_denied".to_string(), - error_description: Some(format!( - "{} returned a malformed response", - path.provider - )), - error_uri: None, - }) - .unwrap() - .into(), - ) - } - }; - - *error_response.status_mut() = StatusCode::BAD_REQUEST; - error_response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - Ok(error_response) - } - } - } - } else { - tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); - - Err(bad_request("Invalid provider")) - } -} diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs new file mode 100644 index 00000000..30a36997 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -0,0 +1,2881 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD}; +use chrono::{TimeDelta, Utc}; +use cookie::{Cookie, SameSite}; +use dropshot::{ + ClientErrorStatusCode, HttpError, HttpResponseOk, HttpResponseTemporaryRedirect, Path, Query, + RequestContext, RequestInfo, SharedExtractor, TypedBody, http_response_temporary_redirect, +}; +use dropshot_authorization_header::basic::BasicAuth; +use http::{HeaderValue, header::SET_COOKIE}; +use newtype_uuid::{GenericUuid, TypedUuid}; +use oauth2::{ + AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, +}; + +use schemars::JsonSchema; +use secrecy::SecretString; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::{fmt::Debug, ops::Add}; +use tap::TapFallible; +use tracing::instrument; +use uuid::Uuid; +use v_model::{ + LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, + permissions::{AsScope, PermissionStorage}, + schema_ext::LoginAttemptState, +}; + +use super::super::{OAuthProvider, OAuthProviderNameParam}; +use crate::endpoints::login::UserInfoProvider; +use crate::{ + authn::key::RawKey, + context::{ApiContext, VContext}, + endpoints::login::{ + LoginError, UserInfo, + oauth::{CheckOAuthClient, OAuthProviderAuthorizationCodePkceInfo}, + }, + error::ApiError, + permissions::{VAppPermission, VPermission}, + response::bad_request, + secrets::OpenApiSecretString, + util::{ + request::RequestCookies, + response::{ResourceError, internal_error, to_internal_error, unauthorized}, + }, +}; + +static LOGIN_ATTEMPT_COOKIE: &str = "__v_login"; +static LOGIN_ATTEMPT_COOKIE_PATH: &str = "/login/oauth/"; +static DEFAULT_SCOPE: &str = "user:info:r"; + +/// Build the login attempt cookie with consistent attributes. +/// The `Path` is scoped to the OAuth login endpoints so the cookie is not +/// sent to unrelated paths on the same domain. +fn build_login_attempt_cookie<'a>( + value: &'a str, + public_url: &str, + max_age_secs: i64, +) -> Cookie<'a> { + let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, value.to_string()); + cookie.set_path(LOGIN_ATTEMPT_COOKIE_PATH); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::Lax); + cookie.set_secure(public_url.starts_with("https")); + cookie.set_max_age(cookie::time::Duration::seconds(max_age_secs)); + cookie +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] +struct OAuthError { + error: OAuthErrorCode, + #[serde(skip_serializing_if = "Option::is_none")] + error_description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error_uri: Option, + #[serde(skip_serializing_if = "Option::is_none")] + state: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum OAuthErrorCode { + AccessDenied, + InvalidClient, + InvalidGrant, + InvalidRequest, + InvalidScope, + ServerError, + TemporarilyUnavailable, + UnauthorizedClient, + UnsupportedGrantType, + UnsupportedResponseType, +} + +impl From for HttpError { + fn from(value: OAuthError) -> Self { + let serialized = serde_json::to_string(&value).unwrap(); + HttpError { + headers: None, + status_code: ClientErrorStatusCode::BAD_REQUEST.into(), + error_code: None, + external_message: serialized.clone(), + internal_message: serialized, + } + } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeQuery { + pub client_id: TypedUuid, + pub redirect_uri: String, + pub response_type: String, + pub state: String, + pub scope: Option, + /// PKCE code challenge (RFC 7636). Required for all authorization code flows. + pub code_challenge: String, + /// PKCE code challenge method. Must be "S256". + pub code_challenge_method: String, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeRedirectHeaders { + #[serde(rename = "set-cookies")] + cookies: String, + location: String, +} + +/// Validate that response_type is "code" per RFC 6749 §4.1.1. +fn validate_response_type(response_type: &str) -> Result<(), OAuthError> { + if response_type == "code" { + Ok(()) + } else { + Err(OAuthError { + error: OAuthErrorCode::UnsupportedResponseType, + error_description: Some("Only response_type=code is supported".to_string()), + error_uri: None, + state: None, + }) + } +} + +// Lookup the client specified by the provided client id and verify that the redirect uri +// is a valid for this client. If either of these fail we return an unauthorized response +async fn get_oauth_client( + ctx: &VContext, + client_id: &TypedUuid, + redirect_uri: &str, +) -> Result +where + T: VAppPermission + PermissionStorage, +{ + let client = ctx + .oauth + .get_oauth_client(&ctx.builtin_registration_user(), client_id) + .await + .map_err(|err| { + tracing::error!(?err, "Failed to lookup OAuth client"); + + match err { + ResourceError::DoesNotExist => OAuthError { + error: OAuthErrorCode::InvalidClient, + error_description: Some("Unknown client id".to_string()), + error_uri: None, + state: None, + }, + // Given that the builtin caller should have access to all OAuth clients, any other + // error is considered an internal error + _ => OAuthError { + error: OAuthErrorCode::ServerError, + error_description: None, + error_uri: None, + state: None, + }, + } + })?; + + if client.is_redirect_uri_valid(redirect_uri) { + Ok(client) + } else { + Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Invalid redirect uri".to_string()), + error_uri: None, + state: None, + }) + } +} + +#[instrument(skip(rqctx), err(Debug))] +pub async fn get_public_pkce_provider_op( + rqctx: &RequestContext>, + path: Path, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let path = path.into_inner(); + + tracing::trace!("Getting OAuth data for {}", path.provider); + + let provider = rqctx + .v_ctx() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider + .authz_code_pkce_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support web pkce clients"))?, + )) +} + +#[instrument(skip(rqctx), err(Debug))] +pub async fn authz_code_redirect_op( + rqctx: &RequestContext>, + path: Path, + query: Query, +) -> Result +where + T: VAppPermission + PermissionStorage, +{ + let ctx = rqctx.v_ctx(); + let path = path.into_inner(); + let query = query.into_inner(); + + get_oauth_client(ctx, &query.client_id, &query.redirect_uri).await?; + + tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); + + // Validate response_type. Only "code" is supported (RFC 6749 §4.1.1). + validate_response_type(&query.response_type)?; + + // Validate the client's PKCE challenge method. Only S256 is supported. + if query.code_challenge_method != "S256" { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some( + "Unsupported code_challenge_method. Only S256 is supported.".to_string(), + ), + error_uri: None, + state: None, + } + .into()); + } + + // Validate the PKCE code challenge. For S256, this must be a base64url-no-pad + // encoding of a SHA256 hash, which is always exactly 43 characters of [A-Za-z0-9_-] + // (RFC 7636 §4.2). + if query.code_challenge.len() != 43 + || !query + .code_challenge + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') + { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some( + "Invalid code_challenge. Must be a base64url-encoded SHA256 hash (43 characters)." + .to_string(), + ), + error_uri: None, + state: None, + } + .into()); + } + + // Find the configured provider for the requested remote backend. We should always have a valid + // provider value, so if this fails then a 500 is returned + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code login"); + + // Check that the passed in scopes are valid. The scopes are not currently restricted by client + let scope = query.scope.unwrap_or_else(|| DEFAULT_SCOPE.to_string()); + if let Err(err) = VPermission::from_scope_arg(&scope) { + tracing::warn!(?err, ?scope, "Client submitted an invalid scope"); + return Err(OAuthError { + error: OAuthErrorCode::InvalidScope, + error_description: Some(format!("Invalid scope: {}", scope)), + error_uri: None, + state: None, + } + .into()); + } + + // Construct a new login attempt with the minimum required values + let mut attempt = NewLoginAttempt::new( + provider.name().to_string(), + query.client_id, + query.redirect_uri, + scope, + ) + .map_err(|err| { + tracing::error!(?err, "Attempted to construct invalid login attempt"); + internal_error("Attempted to construct invalid login attempt".to_string()) + })?; + + // Set a default expiration for the login attempt + // TODO: Make this configurable + attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); + + // Store the client's state value as-is. Per RFC 6749 §4.1.1, the authorization server + // MUST return the state parameter unmodified. The value will be properly percent-encoded + // when it is placed into the redirect URL by `callback_url()` via `append_pair`. + attempt.state = Some(query.state); + + // Always store the client's PKCE challenge so we can verify it during the token exchange. + // This is the client-to-v-api PKCE leg and is mandatory for all flows. + attempt.pkce_challenge = Some(query.code_challenge); + attempt.pkce_challenge_method = Some(query.code_challenge_method); + + // If the remote provider supports PKCE, also set up a challenge for the v-api-to-remote leg. + // This is independent of the client-to-v-api PKCE above. + let remote_pkce_challenge = if provider.supports_pkce() { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + attempt.provider_pkce_verifier = Some(pkce_verifier.secret().to_string()); + Some(pkce_challenge) + } else { + None + }; + + // Store the generated attempt + let attempt = ctx + .login + .create_login_attempt(attempt) + .await + .map_err(to_internal_error)?; + + tracing::info!(?attempt.id, "Created login attempt"); + + oauth_redirect_response( + ctx.public_url(), + &*provider, + &attempt, + remote_pkce_challenge, + ) +} + +fn oauth_redirect_response( + public_url: &str, + provider: &dyn OAuthProvider, + attempt: &LoginAttempt, + code_challenge: Option, +) -> Result { + // We may fail if the provider configuration is not correctly configured + // TODO: This behavior should be changed so that clients are precomputed. We do not need to be + // constructing a new client on every request. That said, we need to ensure the client does not + // maintain state between requests + let client = provider.as_web_client().map_err(to_internal_error)?; + + // Create an attempt cookie header for storing the login attempt. This also acts as our csrf + // check + let attempt_id_str = attempt.id.to_string(); + let cookie = build_login_attempt_cookie(&attempt_id_str, public_url, 600); + let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; + + // Generate the url to the remote provider that the user will be redirected to + let mut authz_url = client + .authorize_url(|| CsrfToken::new(attempt.id.to_string())) + .add_scopes( + provider + .default_scopes() + .iter() + .map(|s| Scope::new(s.to_string())) + .collect::>(), + ); + + // If the caller has provided a code challenge, add it to the url + if let Some(challenge) = code_challenge { + authz_url = authz_url.set_pkce_challenge(challenge); + }; + + let mut redirect = http_response_temporary_redirect(authz_url.url().0.to_string())?; + redirect.headers_mut().append(SET_COOKIE, login_cookie); + + Ok(redirect) +} + +// TODO: Determine if 401 empty responses are correct here +fn verify_csrf( + request: &RequestInfo, + query: &OAuthAuthzCodeReturnQuery, +) -> Result, HttpError> { + // If we are missing the expected state parameter then we can not proceed at all with verifying + // this callback request. We also do not have a redirect uri to send the user to so we instead + // report unauthorized + let attempt_id = query + .state + .as_ref() + .ok_or_else(|| { + tracing::warn!("OAuth callback is missing a state parameter"); + unauthorized() + })? + .parse() + .map_err(|err| { + tracing::warn!(?err, "Failed to parse state"); + unauthorized() + })?; + + // The client must present the attempt cookie at a minimum. Without it we are unable to lookup a + // login attempt to match against. Without the cookie to verify the state parameter we can not + // determine a redirect uri so we instead report unauthorized + let attempt_cookie = request + .cookie(LOGIN_ATTEMPT_COOKIE) + .ok_or_else(|| { + tracing::warn!("OAuth callback is missing a login state cookie"); + unauthorized() + })? + .value() + .parse() + .map_err(|err| { + tracing::warn!(?err, "Failed to parse state cookie"); + unauthorized() + })?; + + // Verify that the attempt_id returned from the state matches the expected client value. If they + // do not match we can not lookup a redirect uri so we instead return unauthorized + if attempt_id != attempt_cookie { + tracing::warn!( + ?attempt_id, + ?attempt_cookie, + "OAuth state does not match expected cookie value" + ); + Err(unauthorized()) + } else { + Ok(attempt_id) + } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeReturnQuery { + pub state: Option, + pub code: Option, + pub error: Option, +} + +#[instrument(skip(rqctx), err(Debug))] +pub async fn authz_code_callback_op( + rqctx: &RequestContext>, + path: Path, + query: Query, +) -> Result +where + T: VAppPermission + PermissionStorage, +{ + let ctx = rqctx.v_ctx(); + let path = path.into_inner(); + let query = query.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); + + // Verify and extract the attempt id before performing any work + let attempt_id = verify_csrf(&rqctx.request, &query)?; + + // Clear the login attempt cookie + let cookie = build_login_attempt_cookie("", ctx.public_url(), 0); + let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; + + let mut redirect = http_response_temporary_redirect( + authz_code_callback_op_inner(ctx, &attempt_id, query.code, query.error).await?, + )?; + redirect.headers_mut().append(SET_COOKIE, login_cookie); + + Ok(redirect) +} + +pub async fn authz_code_callback_op_inner( + ctx: &VContext, + attempt_id: &TypedUuid, + code: Option, + error: Option, +) -> Result +where + T: VAppPermission + PermissionStorage, +{ + // We have now verified the attempt id and can use it to look up the rest of the login attempt + // material to try and complete the flow + let mut attempt = ctx + .login + .get_login_attempt(attempt_id) + .await + .map_err(to_internal_error)? + .ok_or_else(|| { + // If we fail to find a matching attempt, there is not much we can do other than return + // unauthorized + unauthorized() + }) + .and_then(|attempt| { + if attempt.attempt_state == LoginAttemptState::New { + Ok(attempt) + } else { + Err(unauthorized()) + } + })?; + + // Re-validate the redirect URI against the OAuth client's current registered URIs. + // The URI was checked when the login attempt was created, but it may have been removed + // since then. We must not redirect to a URI that is no longer registered (TOCTOU). + let client = ctx + .oauth + .get_oauth_client(&ctx.builtin_registration_user(), &attempt.client_id) + .await?; + if !client.is_redirect_uri_valid(&attempt.redirect_uri) { + tracing::warn!( + redirect_uri = ?attempt.redirect_uri, + client_id = ?attempt.client_id, + "Login attempt redirect URI is no longer registered on the OAuth client" + ); + return Err(unauthorized()); + } + + attempt = match (code, error) { + (Some(code), None) => { + tracing::info!(?attempt.id, "Received valid login attempt. Storing authorization code"); + + // Store the authorization code returned by the underlying OAuth provider and transition the + // attempt to the awaiting state + ctx.login + .set_login_provider_authz_code(attempt, code.to_string()) + .await + .map_err(to_internal_error)? + } + (code, error) => { + tracing::info!(?attempt.id, ?error, "Received an error response from the remote server"); + + // Store the provider return error for future debugging, but if an error has been + // returned or there is a missing code, then we can not report a successful process + attempt.provider_authz_code = code; + + // When a user has explicitly denied access we want to forward that error message + // onwards to the upstream requester. All other errors should be opaque to the + // original requester and are returned as server errors + let error_message = match error.as_deref() { + Some("access_denied") => "access_denied", + _ => "server_error", + }; + + // TODO: Specialize the returned error + ctx.login + .fail_login_attempt( + attempt, + LoginAttemptState::New, + Some(error_message), + error.as_deref(), + ) + .await + .map_err(to_internal_error)? + } + }; + + // Redirect back to the original authenticator + attempt.callback_url().map_err(|err| { + tracing::error!(?err, redirect_uri = ?attempt.redirect_uri, "Login attempt contains an invalid redirect URI"); + to_internal_error(err) + }) +} + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct OAuthAuthzCodeExchangeQuery { + #[serde(default)] + pub request_idp_token: bool, +} + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct OAuthAuthzCodeExchangeBody { + pub client_id: Option>, + pub client_secret: Option, + pub redirect_uri: String, + pub grant_type: String, + pub code: String, + /// PKCE code verifier (RFC 7636). Required for all authorization code exchanges. + pub pkce_verifier: String, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeExchangeResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: i64, + /// The scope granted to the access token (RFC 6749 §5.1). + pub scope: String, + pub idp_token: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeIdpToken { + pub token: String, +} + +#[instrument(skip(rqctx), err(Debug))] +pub async fn authz_code_exchange_op( + rqctx: &RequestContext>, + query: Query, + path: Path, + body: TypedBody, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let ctx = rqctx.v_ctx(); + let query = query.into_inner(); + let path = path.into_inner(); + let body = body.into_inner(); + + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + // Extract basic authorization credentials from the request if they were provided. + let auth = ::from_request(rqctx) + .await + .tap_err(|err| { + tracing::warn!(?err, "Failed to extract basic authentication values"); + }); + let basic_credentials = match auth { + Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(Some(( + TypedUuid::from_untyped_uuid( + Uuid::parse_str(auth.username().unwrap()) + .map_err(|_| bad_request("Malformed client ID presented to code exchange"))?, + ), + auth.password().unwrap().to_string(), + ))), + Ok(auth) if auth.username().is_none() && auth.password().is_none() => { + tracing::info!("Credentials for code exchange not defined via basic auth"); + Ok(None) + } + Ok(_) => Err(bad_request( + "Malformed credentials presented to code exchange", + )), + Err(err) => { + tracing::info!(?err, "Failed to extract basic authentication credentials"); + Ok(None) + } + }?; + + // Extract credentials from the request body if they were provided. + let body_credentials = (body.client_id, body.client_secret); + + // Now validate if the credentials provided by the client support one of our expected schemes. + // We of course deny underspecifying credentials, but we also want to disallow over specifying + // them. For example, if the client provides both basic auth and a client id/secret in the + // request body, we should reject the request. + tracing::debug!( + ?basic_credentials, + ?body_credentials, + "Extracted credentials from request" + ); + let (client_id, client_secret) = match (basic_credentials, body_credentials) { + (Some(_), (Some(_), _)) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some(_), (_, Some(_))) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some((client_id, client_secret)), (None, None)) => Ok(( + client_id, + Some(OpenApiSecretString(SecretString::from(client_secret))), + )), + (None, (Some(client_id), Some(client_secret))) => Ok((client_id, Some(client_secret))), + (None, (Some(client_id), _)) if provider.authz_code_pkce_flow_info().is_some() => { + Ok((client_id, None)) + } + _ => Err(bad_request("Missing client credentials")), + }?; + + tracing::debug!("Attempting code exchange"); + + // Verify the submitted client credentials + authorize_code_exchange( + ctx, + &*provider, + &body.grant_type, + client_id, + client_secret.map(|s| s.0).as_ref(), + &body.redirect_uri, + ) + .await?; + + tracing::debug!("Authorized code exchange"); + + // Lookup the request assigned to this code + let mut attempt = ctx + .login + .get_login_attempt_for_code(&body.code, &provider.name().to_string()) + .await + .map_err(to_internal_error)? + .ok_or(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: None, + error_uri: None, + state: None, + })?; + + // Verify that the login attempt is valid and matches the submitted client credentials + verify_login_attempt( + &attempt, + &provider.name().to_string(), + client_id, + &body.redirect_uri, + &body.pkce_verifier, + )?; + + tracing::debug!("Verified login attempt"); + + // Atomically claim this login attempt before doing any remote work. This transitions + // the attempt from RemoteAuthenticated -> Complete in a single conditional UPDATE, + // ensuring that a concurrent request using the same authorization code will fail. + // Per RFC 6749 §4.1.2, authorization codes MUST be single-use. + let attempt_id = attempt.id; + attempt = ctx + .login + .claim_login_attempt(attempt) + .await + .map_err(|err| { + tracing::warn!( + ?err, + ?attempt_id, + "Failed to claim login attempt (may have been consumed by a concurrent request)" + ); + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Authorization code has already been used".to_string()), + error_uri: None, + state: None, + } + })?; + + tracing::debug!("Claimed login attempt"); + + // Now that the attempt has been claimed, use it to fetch user information from the + // remote provider. If this fails, the attempt is already consumed and the user must + // re-authenticate. The upstream access token is always preserved here so that + // revocation can be deferred until after the permission check. + let (info, upstream_token) = fetch_user_info(ctx.public_url(), &*provider, &attempt).await?; + + tracing::debug!("Retrieved user information from remote provider"); + + complete_exchange( + ctx, + info, + &*provider, + &attempt, + query.request_idp_token, + upstream_token, + ) + .await +} + +async fn complete_exchange( + ctx: &VContext, + info: UserInfo, + provider: &dyn OAuthProvider, + attempt: &LoginAttempt, + request_idp_token: bool, + upstream_token: Option, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let idp_token = info.idp_token.clone(); + + // Register this user as an API user if needed + let (api_user_info, api_user_provider) = ctx + .register_api_user(&ctx.builtin_registration_user(), info) + .await?; + + // Only return the IdP token if the caller requested it AND the user has permission. + // We must resolve the full caller (including group permissions) rather than checking + // only the directly assigned user permissions. + let idp_token = filter_idp_token(ctx, idp_token, request_idp_token, &api_user_info).await; + + // Revoke the upstream access token whenever it will NOT be returned to the caller. + // This covers the cases where the token was never requested, where the user lacks + // the RetrieveRemoteAccessToken permission, and where the provider did not return + // a token at all. + if idp_token.is_none() + && let Some(upstream) = upstream_token + { + revoke_upstream_token(provider, &upstream).await; + } + + tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); + + let scope = attempt + .scope + .split(' ') + .map(|s| s.to_string()) + .collect::>(); + + let token = ctx + .generate_access_token( + &ctx.builtin_registration_user(), + &api_user_info.user.id, + &api_user_provider.id, + Some(scope), + ) + .await?; + + Ok(HttpResponseOk(OAuthAuthzCodeExchangeResponse { + token_type: "Bearer".to_string(), + access_token: token.signed_token, + expires_in: token.expires_in, + scope: attempt.scope.clone(), + idp_token, + })) +} + +/// Filter the IdP token based on whether it was requested and whether the user has +/// the `RetrieveRemoteAccessToken` permission (including permissions inherited from +/// groups). Returns `None` if either condition is not met. +async fn filter_idp_token( + ctx: &VContext, + idp_token: Option, + requested: bool, + api_user_info: &v_model::ApiUserInfo, +) -> Option +where + T: VAppPermission + PermissionStorage, +{ + if !requested { + return None; + } + + // Resolve the caller so that group-inherited permissions are included in the + // permission check, not just directly-assigned user permissions. + let caller = match ctx + .user + .resolve_caller(api_user_info, crate::context::BasePermissions::Full) + .await + { + Ok(caller) => caller, + Err(err) => { + tracing::warn!( + ?err, + "Failed to resolve caller permissions for IdP token check" + ); + return None; + } + }; + + if caller + .permissions + .can(&VPermission::RetrieveRemoteAccessToken.into()) + { + idp_token + } else { + tracing::info!("User requested IdP token but lacks RetrieveRemoteAccessToken permission"); + None + } +} + +async fn authorize_code_exchange( + ctx: &VContext, + provider: &dyn OAuthProvider, + grant_type: &str, + client_id: TypedUuid, + client_secret: Option<&SecretString>, + redirect_uri: &str, +) -> Result<(), OAuthError> +where + T: VAppPermission + PermissionStorage, +{ + let client = get_oauth_client(ctx, &client_id, redirect_uri).await?; + + // Verify that we received the expected grant type + if grant_type != "authorization_code" { + return Err(OAuthError { + error: OAuthErrorCode::UnsupportedGrantType, + error_description: None, + error_uri: None, + state: None, + }); + } + + tracing::debug!(grant_type, "Verified grant type"); + + // If we were provided a client secret, then it must be verified. If a client secret was not + // provided, then we can skip this step as long as the provider supports pkce_only + // authentication. + if let Some(client_secret) = client_secret { + let client_secret = RawKey::try_from(client_secret).map_err(|err| { + tracing::warn!(?err, "Failed to parse OAuth client secret"); + + OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Malformed client secret".to_string()), + error_uri: None, + state: None, + } + })?; + + tracing::debug!("Constructed client secret"); + + if !client.is_secret_valid(&client_secret, ctx) { + Err(OAuthError { + error: OAuthErrorCode::InvalidClient, + error_description: Some("Invalid client secret".to_string()), + error_uri: None, + state: None, + }) + } else { + tracing::debug!("Verified client secret validity"); + + Ok(()) + } + } else if provider.authz_code_pkce_flow_info().is_some() { + Ok(()) + } else { + Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Client secret required".to_string()), + error_uri: None, + state: None, + }) + } +} + +fn verify_login_attempt( + attempt: &LoginAttempt, + provider: &str, + client_id: TypedUuid, + redirect_uri: &str, + pkce_verifier: &str, +) -> Result<(), OAuthError> { + if attempt.provider != provider { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.client_id != client_id { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid client id".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.redirect_uri != redirect_uri { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid redirect uri".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.attempt_state != LoginAttemptState::RemoteAuthenticated { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant is in an invalid state".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.expires_at.map(|t| t <= Utc::now()).unwrap_or(true) { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant has expired".to_string()), + error_uri: None, + state: None, + }) + } else { + match attempt.pkce_challenge.as_deref() { + Some(challenge) => { + let mut hasher = Sha256::new(); + hasher.update(pkce_verifier); + let hash = hasher.finalize(); + let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash); + + if challenge == computed_challenge { + Ok(()) + } else { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid pkce verifier".to_string()), + error_uri: None, + state: None, + }) + } + } + // PKCE is mandatory for all authorization code flows. A missing challenge + // means the login attempt was not properly initialized. + None => Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), + error_uri: None, + state: None, + }), + } + } +} + +/// Revoke an upstream IdP access token if the provider supports revocation. +/// Failures are logged but do not propagate — callers should not fail the +/// overall exchange just because revocation was unsuccessful. +async fn revoke_upstream_token(provider: &dyn OAuthProvider, token_secret: &str) { + let provider_info = match provider.authz_code_flow_info() { + Some(info) => info, + None => return, + }; + + if provider_info.remote.revocation_endpoint.is_some() { + let client = match provider.as_web_client() { + Ok(c) => c, + Err(err) => { + tracing::warn!( + ?err, + "Failed to build web client for upstream token revocation" + ); + return; + } + }; + let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); + let access_token = oauth2::AccessToken::new(token_secret.to_string()); + match client.revoke_token(access_token.into()) { + Ok(req) => { + if let Err(err) = req.request_async(&oauth_client).await { + tracing::warn!(?err, "Failed to revoke upstream IdP access token"); + } + } + Err(err) => { + tracing::warn!( + ?err, + "Failed to build revocation request for upstream token" + ); + } + } + } else { + tracing::debug!("Provider does not support token revocation") + } +} + +#[instrument(skip(attempt))] +async fn fetch_user_info( + public_url: &str, + provider: &dyn OAuthProvider, + attempt: &LoginAttempt, +) -> Result<(UserInfo, Option), HttpError> { + // Exchange the stored authorization code with the remote provider for a remote access token + let client = provider.as_web_client().map_err(to_internal_error)?; + + let mut request = client.exchange_code(AuthorizationCode::new( + attempt + .provider_authz_code + .as_ref() + .ok_or_else(|| { + internal_error("Expected authorization code to exist due to attempt state") + })? + .to_string(), + )); + + if let Some(pkce_verifier) = &attempt.provider_pkce_verifier { + request = request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string())) + } + + if let Some(expires_in) = provider.expires_in() { + request = request.add_extra_param("expires_in", expires_in.to_string()); + } + + let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); + let response = request + .request_async(&oauth_client) + .await + .map_err(to_internal_error)?; + + tracing::info!("Fetched access token from remote service"); + + // Use the retrieved access token to fetch the user information from the remote API + let info = provider + .get_user_info(response.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::info!("Fetched user info from remote service"); + + // Return the upstream access token alongside the user info so the caller + // can decide whether to revoke it after the permission check. + let upstream_token = Some(response.access_token().secret().to_string()); + + Ok((info, upstream_token)) +} + +#[cfg(test)] +mod tests { + use std::{ + net::{Ipv4Addr, SocketAddrV4}, + ops::Add, + sync::{Arc, Mutex}, + }; + + use chrono::{TimeDelta, Utc}; + use dropshot::{HttpResponse, RequestInfo}; + use http::{ + HeaderValue, StatusCode, + header::{COOKIE, LOCATION, SET_COOKIE}, + }; + use http_body_util::Empty; + use mockall::predicate::eq; + use newtype_uuid::TypedUuid; + use oauth2::PkceCodeChallenge; + use secrecy::SecretString; + use uuid::Uuid; + use v_model::{ + AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, LoginAttempt, NewApiUser, + NewApiUserProvider, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, + schema_ext::LoginAttemptState, + storage::{ + MockAccessGroupStore, MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, + MockLoginAttemptStore, MockMapperStore, MockOAuthClientStore, + }, + }; + + use crate::{ + authn::key::RawKey, + context::{ + VContext, + test_mocks::{MockStorage, mock_context}, + }, + endpoints::login::{ + ExternalUserId, UserInfo, + oauth::{ + OAuthProviderName, + flow::code::{ + LOGIN_ATTEMPT_COOKIE, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, + authz_code_callback_op_inner, verify_csrf, verify_login_attempt, + }, + }, + }, + permissions::VPermission, + }; + + use super::{authorize_code_exchange, get_oauth_client, oauth_redirect_response}; + + /// A minimal no-op `OAuthProvider` for unit tests that need to pass a + /// provider reference to `complete_exchange` without performing any real + /// network I/O. `authz_code_flow_info` returns `None`, so + /// `revoke_upstream_token` will short-circuit immediately. + #[derive(Debug)] + struct NoOpOAuthProvider { + client: reqwest::Client, + } + + impl NoOpOAuthProvider { + fn new() -> Self { + Self { + client: reqwest::Client::new(), + } + } + } + + impl crate::endpoints::login::oauth::ExtractUserInfo for NoOpOAuthProvider { + fn extract_user_info( + &self, + _data: &[hyper::body::Bytes], + ) -> Result { + unimplemented!("not used in tests") + } + } + + impl crate::endpoints::login::oauth::OAuthProvider for NoOpOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Google + } + fn initialize_headers(&self, _request: &mut reqwest::Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![] + } + fn authz_code_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderAuthorizationCodeInfo> { + None + } + fn authz_code_pkce_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderAuthorizationCodePkceInfo> + { + None + } + fn device_code_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderDeviceInfo> { + None + } + fn expires_in(&self) -> Option { + None + } + fn default_scopes(&self) -> &[String] { + &[] + } + fn supports_pkce(&self) -> bool { + false + } + } + + /// Create a mock `OAuthClientStore` that returns a client with the given + /// `client_id` and a single registered `redirect_uri`. This is needed by + /// any test that exercises `authz_code_callback_op_inner`, which re-validates + /// the redirect URI against the client before redirecting. + fn mock_oauth_client_store_for_callback( + client_id: TypedUuid, + redirect_uri: &str, + ) -> Arc { + let redirect_uri = redirect_uri.to_string(); + let mut store = MockOAuthClientStore::new(); + store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| { + Ok(Some(OAuthClient { + id: client_id, + secrets: vec![], + redirect_uris: vec![OAuthClientRedirectUri { + id: TypedUuid::new_v4(), + oauth_client_id: client_id, + redirect_uri: redirect_uri.clone(), + created_at: Utc::now(), + deleted_at: None, + }], + created_at: Utc::now(), + deleted_at: None, + })) + }); + Arc::new(store) + } + + async fn mock_client() -> (VContext, OAuthClient, SecretString) { + let ctx = mock_context(Arc::new(MockStorage::new())).await; + let client_id = TypedUuid::new_v4(); + let key = RawKey::generate::<8>(&Uuid::new_v4()) + .sign(ctx.signer()) + .await + .unwrap(); + let secret_signature = key.signature().to_string(); + let client_secret = key.key(); + let redirect_uri = "https://example.com/callback"; + + ( + ctx, + OAuthClient { + id: client_id, + secrets: vec![OAuthClientSecret { + id: TypedUuid::new_v4(), + oauth_client_id: client_id, + secret_signature, + created_at: Utc::now(), + deleted_at: None, + }], + redirect_uris: vec![OAuthClientRedirectUri { + id: TypedUuid::new_v4(), + oauth_client_id: client_id, + redirect_uri: redirect_uri.to_string(), + created_at: Utc::now(), + deleted_at: None, + }], + created_at: Utc::now(), + deleted_at: None, + }, + client_secret, + ) + } + + #[tokio::test] + async fn test_oauth_client_lookup_checks_redirect_uri() { + let client_id = TypedUuid::new_v4(); + let client = OAuthClient { + id: client_id, + secrets: vec![], + redirect_uris: vec![OAuthClientRedirectUri { + id: TypedUuid::new_v4(), + oauth_client_id: client_id, + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + created_at: Utc::now(), + deleted_at: None, + }], + created_at: Utc::now(), + deleted_at: None, + }; + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + let ctx = mock_context(Arc::new(storage)).await; + + let failure = get_oauth_client(&ctx, &client_id, "https://not-test.oxeng.dev/callback") + .await + .unwrap_err(); + assert_eq!(OAuthErrorCode::InvalidRequest, failure.error); + assert_eq!( + Some("Invalid redirect uri".to_string()), + failure.error_description + ); + + let success = get_oauth_client(&ctx, &client_id, "https://test.oxeng.dev/callback").await; + assert_eq!(client_id, success.unwrap().id); + } + + #[tokio::test] + async fn test_remote_provider_redirect_url() { + let storage = MockStorage::new(); + let ctx = mock_context(Arc::new(storage)).await; + + let (challenge, _) = PkceCodeChallenge::new_random_sha256(); + let attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::New, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let response = oauth_redirect_response( + ctx.public_url(), + &*ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(), + &attempt, + Some(challenge.clone()), + ) + .unwrap() + .to_result() + .unwrap(); + let headers = response.headers(); + + let expected_location = format!( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", + attempt.id, + challenge.as_str() + ); + + assert_eq!( + expected_location, + String::from_utf8(headers.get(LOCATION).unwrap().as_bytes().to_vec()).unwrap() + ); + assert_eq!( + format!( + "{}; HttpOnly; SameSite=Lax; Secure; Path=/login/oauth/; Max-Age=600", + attempt.id + ) + .as_str(), + String::from_utf8(headers.get(SET_COOKIE).unwrap().as_bytes().to_vec()) + .unwrap() + .split_once('=') + .unwrap() + .1 + ) + } + + #[tokio::test] + async fn test_csrf_check() { + let id = TypedUuid::new_v4(); + + let mut rq = hyper::Request::new(Empty::<()>::new()); + rq.headers_mut().insert( + COOKIE, + HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, id)).unwrap(), + ); + let with_valid_cookie = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + assert_eq!(id, verify_csrf(&with_valid_cookie, &query).unwrap()); + + let query = OAuthAuthzCodeReturnQuery { + state: None, + code: None, + error: None, + }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_valid_cookie, &query) + .unwrap_err() + .status_code + ); + + let mut rq = hyper::Request::new(Empty::<()>::new()); + rq.headers_mut().insert( + COOKIE, + HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, Uuid::new_v4())).unwrap(), + ); + let with_invalid_cookie = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_invalid_cookie, &query) + .unwrap_err() + .status_code + ); + + let rq = hyper::Request::new(Empty::<()>::new()); + let with_missing_cookie = RequestInfo::new( + &rq, + std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), + ); + let query = OAuthAuthzCodeReturnQuery { + state: Some(id.to_string()), + code: None, + error: None, + }; + assert_eq!( + StatusCode::UNAUTHORIZED, + verify_csrf(&with_missing_cookie, &query) + .unwrap_err() + .status_code + ); + } + + #[tokio::test] + async fn test_callback_fails_when_not_in_new_state() { + let invalid_states = [ + LoginAttemptState::Complete, + LoginAttemptState::Failed, + LoginAttemptState::RemoteAuthenticated, + ]; + + for state in invalid_states { + let attempt_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: state, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut storage = MockStorage::new(); + let mut attempt_store = MockLoginAttemptStore::new(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(attempt.clone()))); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + + let ctx = mock_context(Arc::new(storage)).await; + let err = authz_code_callback_op_inner( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + None, + ) + .await; + + assert_eq!(StatusCode::UNAUTHORIZED, err.unwrap_err().status_code); + } + } + + #[tokio::test] + async fn test_callback_fails_when_error_is_passed() { + let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + returned.error = arg.error; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); + let ctx = mock_context(Arc::new(storage)).await; + + let location = authz_code_callback_op_inner( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + Some("not_access_denied".to_string()), + ) + .await + .unwrap(); + + assert_eq!( + format!("https://test.oxeng.dev/callback?state=ox_state&error=server_error",), + location + ); + } + + #[tokio::test] + async fn test_callback_forwards_access_denied() { + let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + returned.error = arg.error; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); + let ctx = mock_context(Arc::new(storage)).await; + + let location = authz_code_callback_op_inner( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + Some("access_denied".to_string()), + ) + .await + .unwrap(); + + assert_eq!( + format!("https://test.oxeng.dev/callback?state=ox_state&error=access_denied",), + location + ); + } + + #[tokio::test] + async fn test_handles_callback_with_code() { + let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + let extracted_code = Arc::new(Mutex::new(None)); + let extractor = extracted_code.clone(); + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + *extractor.lock().unwrap() = returned.authz_code.clone(); + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); + let ctx = mock_context(Arc::new(storage)).await; + + let location = + authz_code_callback_op_inner(&ctx, &attempt_id, Some("remote-code".to_string()), None) + .await + .unwrap(); + + let lock = extracted_code.lock(); + assert_eq!( + format!( + "https://test.oxeng.dev/callback?state=ox_state&code={}", + lock.unwrap().as_ref().unwrap() + ), + location + ); + } + + #[tokio::test] + async fn test_exchange_checks_client_id_and_redirect() { + let (mut ctx, client, client_secret) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + let wrong_client_id = TypedUuid::new_v4(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(wrong_client_id), eq(false)) + .returning(move |_, _| Ok(None)); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + + // 1. Verify exchange fails when passing an incorrect client id + assert_eq!( + Some("Unknown client id".to_string()), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + wrong_client_id, + Some(&client_secret), + &redirect_uri, + ) + .await + .unwrap_err() + .error_description + ); + + // 2. Verify exchange fails when passing an incorrect redirect uri + assert_eq!( + Some("Invalid redirect uri".to_string()), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&client_secret), + "wrong-callback-destination", + ) + .await + .unwrap_err() + .error_description + ); + + // 3. Verify a successful exchange with a client secret + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&client_secret), + &redirect_uri, + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_exchange_requires_secret_except_for_pkce_only() { + let (mut ctx, client, _) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + let pkce_only_provider = ctx + .get_oauth_provider(&OAuthProviderName::Zendesk) + .await + .unwrap(); + + // 1. Verify exchange fails when not passing a client secret for a client that does not + // support pkce_only + assert_eq!( + Some("Client secret required".to_string()), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + None, + &redirect_uri, + ) + .await + .unwrap_err() + .error_description + ); + + // 2. Verify exchange passes when omitting the client secret for a client that does + // support pkce_only + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*pkce_only_provider, + "authorization_code", + client_id, + None, + &redirect_uri, + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_exchange_checks_grant_type() { + let (mut ctx, client, client_secret) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + + assert_eq!( + OAuthErrorCode::UnsupportedGrantType, + authorize_code_exchange( + &ctx, + &*provider, + "not_authorization_code", + client_id, + Some(&client_secret), + &redirect_uri + ) + .await + .unwrap_err() + .error + ); + + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&client_secret), + &redirect_uri + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_exchange_checks_for_valid_secret() { + let (mut ctx, client, client_secret) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + + let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) + .sign(ctx.signer()) + .await + .unwrap() + .signature() + .to_string(); + + assert_eq!( + OAuthErrorCode::InvalidRequest, + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&"too-short".to_string().into()), + &redirect_uri + ) + .await + .unwrap_err() + .error + ); + + assert_eq!( + OAuthErrorCode::InvalidClient, + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&invalid_secret.into()), + &redirect_uri + ) + .await + .unwrap_err() + .error + ); + + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + Some(&client_secret), + &redirect_uri + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_login_attempt_verification() { + let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + let attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some(challenge.as_str().to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let bad_client_id = LoginAttempt { + client_id: TypedUuid::new_v4(), + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid client id".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &bad_client_id, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let bad_redirect_uri = LoginAttempt { + redirect_uri: "https://bad.oxeng.dev/callback".to_string(), + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid redirect uri".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &bad_redirect_uri, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let unconfirmed_state = LoginAttempt { + attempt_state: LoginAttemptState::New, + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant is in an invalid state".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &unconfirmed_state, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let already_used_state = LoginAttempt { + attempt_state: LoginAttemptState::Complete, + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant is in an invalid state".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &already_used_state, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let failed_state = LoginAttempt { + attempt_state: LoginAttemptState::Failed, + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant is in an invalid state".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &failed_state, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let expired = LoginAttempt { + expires_at: Some(Utc::now()), + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Grant has expired".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &expired, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + // Verify that a login attempt with no stored PKCE challenge is rejected. + // PKCE is mandatory, so a missing challenge means the attempt is invalid. + let missing_challenge = LoginAttempt { + pkce_challenge: None, + pkce_challenge_method: None, + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &missing_challenge, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + let invalid_pkce = LoginAttempt { + pkce_challenge: Some("no-the-correct-value".to_string()), + ..attempt.clone() + }; + + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Invalid pkce verifier".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &invalid_pkce, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + assert_eq!( + (), + verify_login_attempt( + &attempt, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap() + ); + } + + #[tokio::test] + async fn test_provider_mismatch_is_rejected() { + let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + + // Login attempt was created via Google + let attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some(challenge.as_str().to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + // Exchanging against a different provider must fail + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &attempt, + "github", + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + // Exchanging against the correct provider must succeed + assert_eq!( + (), + verify_login_attempt( + &attempt, + "google", + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap() + ); + } + + #[test] + fn test_login_attempt_cookie_has_path() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.path(), Some(super::LOGIN_ATTEMPT_COOKIE_PATH)); + } + + #[test] + fn test_login_attempt_cookie_is_http_only() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.http_only(), Some(true)); + } + + #[test] + fn test_login_attempt_cookie_is_same_site_lax() { + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + + assert_eq!(cookie.same_site(), Some(cookie::SameSite::Lax)); + } + + #[test] + fn test_login_attempt_cookie_is_secure_for_https() { + let https_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + assert_eq!(https_cookie.secure(), Some(true)); + + let http_cookie = + super::build_login_attempt_cookie("test-attempt-id", "http://localhost", 600); + assert_eq!(http_cookie.secure(), Some(false)); + } + + #[test] + fn test_login_attempt_clear_cookie_has_same_path() { + // The clear cookie must use the same Path as the set cookie, + // otherwise browsers won't clear it. + let set_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + let clear_cookie = super::build_login_attempt_cookie("", "https://example.com", 0); + + assert_eq!(set_cookie.path(), clear_cookie.path()); + assert_eq!( + clear_cookie.max_age(), + Some(cookie::time::Duration::seconds(0)) + ); + } + + #[test] + fn test_valid_response_type_is_accepted() { + assert!(super::validate_response_type("code").is_ok()); + } + + #[test] + fn test_invalid_response_type_is_rejected() { + let err = super::validate_response_type("token").unwrap_err(); + assert_eq!(err.error, OAuthErrorCode::UnsupportedResponseType); + } + + #[test] + fn test_empty_response_type_is_rejected() { + assert!(super::validate_response_type("").is_err()); + } + + #[test] + fn test_response_type_rejects_similar_values() { + assert!(super::validate_response_type("Code").is_err()); + assert!(super::validate_response_type("CODE").is_err()); + assert!(super::validate_response_type("code ").is_err()); + assert!(super::validate_response_type("token").is_err()); + assert!(super::validate_response_type("code token").is_err()); + } + + /// Create a mock context and ApiUserInfo for `filter_idp_token` tests. + async fn mock_filter_idp_token_ctx( + user_permissions: Vec, + ) -> (VContext, ApiUserInfo) { + let mut access_group_store = MockAccessGroupStore::new(); + access_group_store + .expect_list() + .returning(|_, _| Ok(vec![])); + + let mut storage = MockStorage::new(); + storage.access_group_store = Some(Arc::new(access_group_store)); + + let ctx = mock_context(Arc::new(storage)).await; + let info = ApiUserInfo { + user: ApiUser { + id: TypedUuid::new_v4(), + permissions: user_permissions.into(), + groups: Default::default(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }, + email: None, + providers: vec![], + }; + (ctx, info) + } + + #[tokio::test] + async fn test_filter_idp_token_returns_token_when_requested_and_permitted() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(&ctx, token, true, &info).await; + assert_eq!(result, Some("idp-token-value".to_string())); + } + + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_not_requested() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; + let token = Some("idp-token-value".to_string()); + + // Even with the permission, if not requested the token is not returned + let result = super::filter_idp_token(&ctx, token, false, &info).await; + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_permission_missing() { + // User has some permissions but not RetrieveRemoteAccessToken + let (ctx, info) = mock_filter_idp_token_ctx(vec![VPermission::CreateApiUser]).await; + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(&ctx, token, true, &info).await; + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_no_permissions() { + let (ctx, info) = mock_filter_idp_token_ctx(vec![]).await; + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(&ctx, token, true, &info).await; + assert_eq!(result, None); + } + + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_token_is_none() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; + + // Token was None (e.g. revoked upstream) — should stay None regardless of permission + let result = super::filter_idp_token(&ctx, None, true, &info).await; + assert_eq!(result, None); + } + + /// Set up mock storage for `complete_exchange` tests. The registered user will + /// have the given `user_permissions`. + fn mock_exchange_storage(user_permissions: Vec) -> MockStorage { + // ApiUserProviderStore: list returns empty (new user), upsert returns a provider + let mut provider_store = MockApiUserProviderStore::new(); + provider_store + .expect_list() + .returning(move |_, _| Ok(vec![])); + provider_store + .expect_upsert() + .returning(move |p: NewApiUserProvider| { + Ok(ApiUserProvider { + id: p.id, + user_id: p.user_id, + provider: p.provider, + provider_id: p.provider_id, + emails: p.emails, + display_names: p.display_names, + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }) + }); + + // ApiUserStore: upsert creates a user with the specified permissions + let mut user_store = MockApiUserStore::new(); + user_store + .expect_upsert() + .returning(move |u: NewApiUser| { + Ok(ApiUserInfo { + user: ApiUser { + id: u.id, + permissions: user_permissions.clone().into(), + groups: u.groups, + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }, + email: None, + providers: vec![], + }) + }); + + // MapperStore: list returns empty (no mappers configured) + let mut mapper_store = MockMapperStore::new(); + mapper_store.expect_list().returning(|_, _| Ok(vec![])); + + // AccessTokenStore: upsert returns a token + let mut access_token_store = MockAccessTokenStore::new(); + access_token_store.expect_upsert().returning(|token| { + Ok(AccessToken { + id: token.id, + user_id: token.user_id, + revoked_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }) + }); + + // AccessGroupStore: list returns empty (no groups configured) + let mut access_group_store = MockAccessGroupStore::new(); + access_group_store + .expect_list() + .returning(|_, _| Ok(vec![])); + + let mut storage = MockStorage::new(); + storage.api_user_provider_store = Some(Arc::new(provider_store)); + storage.api_user_store = Some(Arc::new(user_store)); + storage.mapper_store = Some(Arc::new(mapper_store)); + storage.access_token_store = Some(Arc::new(access_token_store)); + storage.access_group_store = Some(Arc::new(access_group_store)); + storage + } + + fn mock_user_info_with_idp_token() -> UserInfo { + UserInfo { + external_id: ExternalUserId::Google("test-google-id".to_string()), + verified_emails: vec!["user@example.com".to_string()], + display_name: Some("Test User".to_string()), + idp_token: Some("secret-upstream-token".to_string()), + } + } + + fn mock_completed_attempt() -> LoginAttempt { + LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::Complete, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://example.com/callback".to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: Some("test-code".to_string()), + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(300).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: Some("remote-code".to_string()), + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + } + } + + #[tokio::test] + async fn test_exchange_returns_idp_token_when_requested_and_permitted() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken, + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); + + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + true, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, + Some("secret-upstream-token".to_string()), + "IdP token must be returned when requested and user has RetrieveRemoteAccessToken" + ); + } + + #[tokio::test] + async fn test_exchange_omits_idp_token_when_permission_missing() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + // Notably missing: VPermission::RetrieveRemoteAccessToken + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); + + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + true, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, None, + "IdP token must NOT be returned when user lacks RetrieveRemoteAccessToken" + ); + } + + /// Verifies that the `state` parameter survives the authorization code flow + /// round trip without modification, as required by RFC 6749 §4.1.1. The + /// authorization server MUST return the exact `state` value that the client + /// originally provided. This test uses a state value containing characters + /// that require percent-encoding (`+`, `/`, spaces, `&`, `=`) to ensure + /// they are encoded exactly once in the final redirect URL and decoded back + /// to the original value by standard URL parsing. + #[tokio::test] + async fn test_state_roundtrip_preserves_special_characters() { + let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); + let original_state = "random+state/with spaces&special=chars"; + + // State is now stored as-is (no pre-encoding). callback_url() handles + // percent-encoding when building the redirect URL. + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some(original_state.to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); + let ctx = mock_context(Arc::new(storage)).await; + + let location = + authz_code_callback_op_inner(&ctx, &attempt_id, Some("remote-code".to_string()), None) + .await + .unwrap(); + + let url = url::Url::parse(&location).unwrap(); + let returned_state = url + .query_pairs() + .find(|(k, _)| k == "state") + .map(|(_, v)| v.into_owned()) + .expect("state parameter must be present in callback URL"); + + // RFC 6749 §4.1.1: the state value MUST be returned to the client + // unmodified. The client sent `original_state`, so it should get back + // exactly `original_state` after URL decoding. + assert_eq!( + original_state, returned_state, + "RFC 6749 §4.1.1 requires the state parameter to be returned unmodified. \ + The client sent {:?} but received {:?}.", + original_state, returned_state, + ); + } + + /// RFC 6749 §5.1 requires the token response to include a `scope` parameter + /// when the issued scope differs from what the client requested, and recommends + /// it in all cases. The token response should echo back the scope that was + /// granted so clients can verify what permissions they received. + #[tokio::test] + async fn test_exchange_response_includes_scope() { + let storage = mock_exchange_storage(vec![VPermission::CreateAccessToken]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); // scope = "user:info:r" + let info = UserInfo { + external_id: ExternalUserId::Google("test-google-id".to_string()), + verified_emails: vec!["user@example.com".to_string()], + display_name: Some("Test User".to_string()), + idp_token: None, + }; + + let provider = NoOpOAuthProvider::new(); + + let response = super::complete_exchange(&ctx, info, &provider, &attempt, false, None) + .await + .unwrap() + .0; + + // Serialize the response to JSON and check for a "scope" field. + // Per RFC 6749 §5.1, the authorization server SHOULD include the scope + // in the token response, and MUST include it if it differs from what + // the client requested. + let json = serde_json::to_value(&response).unwrap(); + assert!( + json.get("scope").is_some(), + "Token response must include a 'scope' field per RFC 6749 §5.1. \ + The login attempt had scope {:?} but the response was: {}", + attempt.scope, + serde_json::to_string_pretty(&json).unwrap(), + ); + } + + #[tokio::test] + async fn test_exchange_omits_idp_token_when_not_requested() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken, + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); + + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + false, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, None, + "IdP token must NOT be returned when not requested, even with permission" + ); + } + + /// The OAuth callback (`authz_code_callback_op_inner`) redirects the user to + /// the `redirect_uri` stored in the login attempt without re-validating it + /// against the OAuth client's currently registered redirect URIs. This means + /// that if a redirect URI is removed from the client between the authorization + /// request and the callback, the redirect still proceeds to the now-deregistered + /// URI (a TOCTOU gap). The callback should re-validate the redirect URI before + /// using it. + #[tokio::test] + async fn test_callback_revalidates_redirect_uri() { + let client_id = TypedUuid::new_v4(); + // The login attempt was created with a redirect_uri that was valid at the + // time, but has since been removed from the client's allowed list. + let deregistered_uri = "https://formerly-valid.example.com/callback"; + + let attempt_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: deregistered_uri.to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt_id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + Ok(returned) + }); + + // Configure the OAuth client with NO registered redirect URIs, + // simulating that the URI was removed after the login attempt + // was created. + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| { + Ok(Some(OAuthClient { + id: client_id, + secrets: vec![], + redirect_uris: vec![], // No registered URIs + created_at: Utc::now(), + deleted_at: None, + })) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(Arc::new(client_store)); + let ctx = mock_context(Arc::new(storage)).await; + + // The callback should reject the request because the redirect URI is no + // longer registered on the OAuth client. + let err = authz_code_callback_op_inner( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + None, + ) + .await + .expect_err( + "Callback should fail when the redirect URI is no longer registered on the client", + ); + + assert_eq!( + err.status_code, + StatusCode::UNAUTHORIZED, + "Expected 401 when redirect URI is deregistered, got {}", + err.status_code, + ); + } + + /// The authorization code lookup should filter by provider so that a code + /// issued for one provider (e.g. Google) is not returned when exchanging + /// against a different provider (e.g. GitHub). This is a defense-in-depth + /// measure — codes should be scoped to their issuing provider at the query + /// level rather than relying solely on post-lookup validation. + #[tokio::test] + async fn test_code_lookup_filters_by_provider() { + // Create a login attempt that was authenticated via Google + let google_attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: Some("authz-code-for-google".to_string()), + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: Some("remote-code".to_string()), + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + }; + + // The mock store simulates a real database: it only returns the + // attempt when the filter's provider field matches. + let returned_attempt = google_attempt.clone(); + let mut attempt_store = MockLoginAttemptStore::new(); + attempt_store.expect_list().returning(move |filter, _| { + let dominated = &returned_attempt; + if let Some(providers) = &filter.provider { + if providers.iter().any(|p| p == &dominated.provider) { + Ok(vec![dominated.clone()]) + } else { + Ok(vec![]) + } + } else { + Ok(vec![dominated.clone()]) + } + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + let ctx = mock_context(Arc::new(storage)).await; + + // Looking up the code for the correct provider should succeed. + let google_result = ctx + .login + .get_login_attempt_for_code("authz-code-for-google", "google") + .await + .unwrap(); + assert!( + google_result.is_some(), + "Code lookup for the issuing provider must return the attempt" + ); + + // Looking up the same code but for a different provider should return + // None, because the provider filter now scopes the query. + let github_result = ctx + .login + .get_login_attempt_for_code("authz-code-for-google", "github") + .await + .unwrap(); + assert!( + github_result.is_none(), + "Code lookup must not return an attempt for a different provider. \ + Expected None, but got {:?}.", + github_result.as_ref().map(|a| &a.provider), + ); + } +} diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs new file mode 100644 index 00000000..b2521770 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -0,0 +1,535 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use chrono::{DateTime, Utc}; +use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; +use http::{HeaderMap, HeaderValue, Response, StatusCode, header}; +use hyper::body::Bytes; +use oauth2::{EmptyExtraTokenFields, StandardTokenResponse, TokenResponse, basic::BasicTokenType}; +use schemars::JsonSchema; +use secrecy::ExposeSecret; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use tap::TapFallible; +use tracing::instrument; +use v_model::permissions::PermissionStorage; + +use super::super::OAuthProviderNameParam; +use crate::endpoints::login::UserInfoProvider; +use crate::{ + context::ApiContext, + endpoints::login::{LoginError, oauth::OAuthProviderDeviceInfo}, + error::ApiError, + permissions::VAppPermission, + response::internal_error, + util::response::bad_request, +}; + +#[instrument(skip(rqctx), err(Debug))] +pub async fn get_device_provider_op( + rqctx: &RequestContext>, + path: Path, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let path = path.into_inner(); + + tracing::trace!("Getting OAuth data for {}", path.provider); + + let provider = rqctx + .v_ctx() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider + .device_code_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support device clients"))?, + )) +} + +#[derive(Debug, Deserialize, Serialize, JsonSchema)] +pub struct AccessTokenExchangeRequest { + pub device_code: String, + pub grant_type: String, + pub expires_at: Option>, +} + +#[derive(Serialize)] +pub struct AccessTokenExchange { + provider: ProviderTokenExchange, + expires_at: Option>, +} + +#[derive(Serialize)] +pub struct ProviderTokenExchange { + client_id: String, + device_code: String, + grant_type: String, + client_secret: String, +} + +impl AccessTokenExchange { + pub fn new(req: AccessTokenExchangeRequest, provider: &OAuthProviderDeviceInfo) -> Self { + Self { + provider: ProviderTokenExchange { + client_id: provider.remote_client_id.clone(), + device_code: req.device_code, + grant_type: req.grant_type, + client_secret: provider.remote_client_secret.0.expose_secret().to_string(), + }, + expires_at: req.expires_at, + } + } +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct ProxyTokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: Option, + pub refresh_token: Option, + pub scopes: Option>, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct ProxyTokenError { + error: String, + error_description: Option, + error_uri: Option, +} + +// Complete a device exchange request against the specified provider. This effectively proxies the +// requests that would go to the provider, captures the returned access tokens, and registers a +// new internal user as needed. The user is then returned an token that is valid for interacting +// with the API +#[instrument(skip(rqctx, body), err(Debug))] +pub async fn exchange_device_token_op( + rqctx: &RequestContext>, + path: Path, + body: TypedBody, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let ctx = rqctx.v_ctx(); + let path = path.into_inner(); + let provider = ctx + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + let device_info = provider.device_code_flow_info(); + + tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); + + if device_info.is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some(format!( + "{} does not support device code flow", + path.provider + )), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + + let device_info = device_info.unwrap(); + let exchange_request = body.into_inner(); + + // Validate grant_type per RFC 8628 §3.4 + if !validate_device_grant_type(&exchange_request.grant_type) { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some( + "grant_type must be urn:ietf:params:oauth:grant-type:device_code" + .to_string(), + ), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + + let exchange = AccessTokenExchange::new(exchange_request, device_info); + + let client = reqwest::Client::new(); + + let response = client + .request(Method::POST, &device_info.token_endpoint) + .header( + header::CONTENT_TYPE, + &device_info.token_endpoint_content_type, + ) + .header(header::ACCEPT, HeaderValue::from_static("application/json")) + .body( + // We know that this is safe to unwrap as we just deserialized it via the body Extractor + serde_urlencoded::to_string(&exchange.provider).unwrap(), + ) + .send() + .await + .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) + .map_err(internal_error)?; + + // Take a part the response as we will need the individual parts later + let status = response.status(); + let headers = response.headers().clone(); + let bytes = response.bytes().await.map_err(internal_error)?; + + // We unfortunately can not trust our providers to follow specs and therefore need to do + // our own inspection of the response to determine what to do + if !status.is_success() { + // If the server returned a non-success status then we are going to trust the server and + // report their error back to the client + tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); + + Ok(proxy_upstream_response(bytes, headers, status)) + } else { + // The server gave us back a non-error response but it still may not be a success. + // GitHub for instance does not use a status code for indicating the success or failure + // of a call. So instead we try to deserialize the body into an access token, with the + // understanding that it may fail and we will need to try and treat the response as + // an error instead. + + let parsed: Result< + StandardTokenResponse, + serde_json::Error, + > = serde_json::from_slice(&bytes); + + match parsed { + Ok(parsed) => { + let info = provider + .get_user_info(parsed.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::debug!("Verified and validated OAuth user"); + + let (api_user_info, api_user_provider) = ctx + .register_api_user(&ctx.builtin_registration_user(), info) + .await?; + + tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); + + let claims = + ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); + let token = ctx + .user + .register_access_token( + &ctx.builtin_registration_user(), + ctx.jwt_signer(), + &api_user_info.user.id, + &claims, + ) + .await?; + + tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_string(&ProxyTokenResponse { + access_token: token.signed_token, + token_type: "Bearer".to_string(), + expires_in: Some(claims.exp - Utc::now().timestamp()), + refresh_token: None, + scopes: None, + }) + .unwrap() + .into(), + )?) + } + Err(_) => { + // Do not log the error here as we want to ensure we do not leak token information + tracing::debug!( + "Failed to parse a success response from the remote token endpoint" + ); + + Ok(handle_token_parse_failure( + &path.provider.to_string(), + bytes, + headers, + status, + )) + } + } + } +} + +/// Validate the grant_type for device code exchange per RFC 8628 §3.4. +fn validate_device_grant_type(grant_type: &str) -> bool { + grant_type == "urn:ietf:params:oauth:grant-type:device_code" +} + +/// Headers that are safe to forward from an upstream OAuth provider response. +/// Only `Content-Type` is needed so the client can parse the body. Polling backoff +/// is handled via the JSON body per RFC 8628 (`interval` field / `slow_down` error), +/// not via HTTP headers. +const FORWARDED_HEADERS: &[header::HeaderName] = &[header::CONTENT_TYPE]; + +/// Copy only allowlisted headers from an upstream response to avoid forwarding +/// dangerous headers such as `Set-Cookie`, `Location`, or CORS headers. +fn filter_upstream_headers(upstream: &HeaderMap) -> HeaderMap { + let mut filtered = HeaderMap::new(); + for name in FORWARDED_HEADERS { + if let Some(value) = upstream.get(name) { + filtered.insert(name.clone(), value.clone()); + } + } + filtered +} + +/// Build a response to the client by proxying an upstream provider's response. This is used +/// when the upstream provider returns a non-success status code. +fn proxy_upstream_response(bytes: Bytes, headers: HeaderMap, status: StatusCode) -> Response { + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + client_response +} + +/// Handle the case where the upstream provider returned a 200 status but the body could not be +/// parsed as a valid token response. We try to interpret the body as an error response and proxy +/// it back. If the body is not a recognizable error either, we return our own error. +fn handle_token_parse_failure( + provider_name: &str, + bytes: Bytes, + headers: HeaderMap, + status: StatusCode, +) -> Response { + // Try to deserialize the body as an error + let mut error_response = match serde_json::from_slice::(&bytes) { + Ok(error) => { + // We found an error in the message body. This is not ideal, but we at + // least can understand what the server was trying to tell us + tracing::debug!( + ?error, + provider_name, + "Parsed error response from OAuth provider" + ); + + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + + client_response + } + Err(_) => { + // We still do not know what the remote server is doing... and need to + // cancel the request ourselves + tracing::warn!("Remote OAuth provider returned a response that we do not understand"); + + Response::new( + serde_json::to_vec(&ProxyTokenError { + error: "access_denied".to_string(), + error_description: Some(format!( + "{} returned a malformed response", + provider_name + )), + error_uri: None, + }) + .unwrap() + .into(), + ) + } + }; + + *error_response.status_mut() = StatusCode::BAD_REQUEST; + error_response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + error_response +} + +#[cfg(test)] +mod tests { + use http::{ + HeaderMap, HeaderValue, StatusCode, + header::{self, HeaderName, SET_COOKIE}, + }; + use hyper::body::Bytes; + + use super::{handle_token_parse_failure, proxy_upstream_response, validate_device_grant_type}; + + #[test] + fn test_upstream_set_cookie_is_stripped_from_error_response() { + // A malicious or compromised upstream provider includes a Set-Cookie header + // that would set a cookie on our API's domain in the user's browser + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("session=malicious-value; Path=/; HttpOnly"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::FORBIDDEN); + + // The Set-Cookie header must NOT be forwarded to our client + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded to the client" + ); + // But Content-Type should still be forwarded + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be forwarded from upstream" + ); + } + + #[test] + fn test_upstream_cors_headers_are_stripped_from_error_response() { + // A malicious upstream provider injects permissive CORS headers that would + // weaken our API's cross-origin protections + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("*"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::BAD_REQUEST); + + // CORS headers must NOT be forwarded from upstream + assert!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none(), + "Upstream CORS header must not be forwarded to the client" + ); + assert!( + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS) + .is_none(), + "Upstream CORS credentials header must not be forwarded to the client" + ); + } + + #[test] + fn test_upstream_location_and_framing_headers_are_stripped_from_token_parse_failure() { + // When the upstream returns a 200 status but the body is an error (not a valid + // token), handle_token_parse_failure must not forward dangerous headers. + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::LOCATION, + HeaderValue::from_static("https://evil.example.com/phishing"), + ); + upstream_headers.insert( + HeaderName::from_static("x-frame-options"), + HeaderValue::from_static("ALLOW-FROM https://evil.example.com"), + ); + + // Body that parses as a ProxyTokenError but NOT as a valid token + let body = Bytes::from_static( + b"{\"error\": \"slow_down\", \"error_description\": null, \"error_uri\": null}", + ); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // Dangerous headers must NOT be forwarded + assert!( + response.headers().get(header::LOCATION).is_none(), + "Upstream Location header must not be forwarded to the client" + ); + assert!( + response.headers().get("x-frame-options").is_none(), + "Upstream X-Frame-Options header must not be forwarded to the client" + ); + // But Content-Type should still be present (set by the function itself) + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be present on the response" + ); + } + + #[test] + fn test_upstream_set_cookie_is_stripped_from_token_parse_failure() { + // Even when the upstream returns 200 and the body is a parseable error, + // a Set-Cookie header must not be forwarded to our client + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("tracking=evil-tracker; Domain=.our-api.com; Path=/"), + ); + + let body = Bytes::from_static( + b"{\"error\": \"authorization_pending\", \"error_description\": null, \"error_uri\": null}", + ); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // The Set-Cookie header must NOT be forwarded + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded via token parse failure path" + ); + } + + #[test] + fn test_valid_device_grant_type_is_accepted() { + assert!(validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:device_code" + )); + } + + #[test] + fn test_invalid_device_grant_type_is_rejected() { + assert!(!validate_device_grant_type("authorization_code")); + } + + #[test] + fn test_empty_device_grant_type_is_rejected() { + assert!(!validate_device_grant_type("")); + } + + #[test] + fn test_device_grant_type_rejects_similar_values() { + assert!(!validate_device_grant_type("device_code")); + assert!(!validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:device_Code" + )); + assert!(!validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:authorization_code" + )); + } +} diff --git a/v-api/src/endpoints/login/oauth/flow/mod.rs b/v-api/src/endpoints/login/oauth/flow/mod.rs new file mode 100644 index 00000000..305cd9ab --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/mod.rs @@ -0,0 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +pub mod code; +pub mod device_token; diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs deleted file mode 100644 index a38024d7..00000000 --- a/v-api/src/endpoints/login/oauth/google.rs +++ /dev/null @@ -1,188 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use hyper::body::Bytes; -use reqwest::Request; -use secrecy::SecretString; -use serde::Deserialize; -use std::fmt; - -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, -}; - -pub struct GoogleOAuthProvider { - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, - client: reqwest::Client, -} - -impl fmt::Debug for GoogleOAuthProvider { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("GoogleOAuthProvider").finish() - } -} - -impl GoogleOAuthProvider { - pub fn new( - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, - additional_scopes: Option>, - ) -> Self { - Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, - }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), - client: reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("Static client must build"), - } - } - - pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { - self.client = client; - self - } -} - -#[derive(Debug, Deserialize)] -struct GoogleUserInfo { - sub: String, - email: String, - email_verified: bool, -} - -#[derive(Debug, Deserialize)] -struct GoogleProfile { - #[serde(default)] - names: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GoogleProfileName { - display_name: String, - metadata: GoogleProfileNameMeta, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GoogleProfileNameMeta { - #[serde(default)] - primary: bool, -} - -impl ExtractUserInfo for GoogleOAuthProvider { - // There should always be as many entries in the data list as there are endpoints. This should - // be changed in the future to be a static check - fn extract_user_info(&self, data: &[Bytes]) -> Result { - let remote_info: GoogleUserInfo = serde_json::from_slice(&data[0])?; - let verified_emails = if remote_info.email_verified { - vec![remote_info.email] - } else { - vec![] - }; - - let profile_info: GoogleProfile = serde_json::from_slice(&data[1])?; - let display_name = profile_info - .names - .into_iter() - .filter_map(|name| name.metadata.primary.then_some(name.display_name)) - .nth(0); - - Ok(UserInfo { - external_id: ExternalUserId::Google(remote_info.sub), - verified_emails, - display_name, - }) - } -} - -impl OAuthProvider for GoogleOAuthProvider { - fn name(&self) -> OAuthProviderName { - OAuthProviderName::Google - } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["openid", "email", "profile"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - - fn initialize_headers(&self, _request: &mut Request) {} - - fn client(&self) -> &reqwest::Client { - &self.client - } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - - fn user_info_endpoints(&self) -> Vec<&str> { - vec![ - "https://openidconnect.googleapis.com/v1/userinfo", - "https://people.googleapis.com/v1/people/me?personFields=names", - ] - } - - fn device_code_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/device/code" - } - - fn auth_url_endpoint(&self) -> &str { - "https://accounts.google.com/o/oauth2/v2/auth" - } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" - } - - fn token_exchange_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/token" - } - - fn token_revocation_endpoint(&self) -> Option<&str> { - Some("https://oauth2.googleapis.com/revoke") - } - - fn supports_pkce(&self) -> bool { - true - } -} diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index a5ddc119..a592373f 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -5,44 +5,48 @@ use async_trait::async_trait; use http::Method; use hyper::{body::Bytes, header::AUTHORIZATION, header::HeaderValue}; +use newtype_uuid::TypedUuid; use oauth2::{ AuthUrl, ClientId, ClientSecret, EndpointMaybeSet, EndpointNotSet, EndpointSet, RedirectUrl, RevocationUrl, TokenUrl, basic::BasicClient, url::ParseError, }; use reqwest::Request; use schemars::JsonSchema; -use secrecy::{ExposeSecret, SecretString}; +use secrecy::ExposeSecret; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display}; use thiserror::Error; use tracing::instrument; -use v_model::OAuthClient; +use v_model::{OAuthClient, OAuthClientId}; -use crate::authn::{Verify, key::RawKey}; +use crate::{ + authn::{Verify, key::RawKey}, + secrets::OpenApiSecretString, +}; -use super::{UserInfo, UserInfoError, UserInfoProvider}; +use super::{UserInfo, UserInfoError, UserInfoProvider, is_redirect_uri_valid}; pub mod client; -pub mod code; -pub mod device_token; -pub mod github; -pub mod google; +pub mod flow; +pub mod remote; #[derive(Debug, Error)] pub enum OAuthProviderError { #[error("Unable to instantiate invalid provider")] FailToCreateInvalidProvider, + #[error("Missing redirect URI")] + MissingRedirectUri, + #[error("Failed to parse URL")] + UrlParseError(#[from] ParseError), + #[error("Provider does not support web clients")] + WebClientNotSupported, } #[derive(Debug)] pub enum ClientType { Device, Web, -} - -#[derive(Debug)] -pub struct WebClientConfig { - prefix: String, + WebPkce, } pub type WebClient = BasicClient< @@ -58,67 +62,47 @@ pub type WebClient = BasicClient< EndpointSet, >; -pub struct OAuthPublicCredentials { - client_id: String, -} - -pub struct OAuthPrivateCredentials { - client_secret: SecretString, -} - pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn name(&self) -> OAuthProviderName; - fn scopes(&self) -> Vec<&str>; fn initialize_headers(&self, request: &mut Request); fn client(&self) -> &reqwest::Client; - fn client_id(&self, client_type: &ClientType) -> &str; - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString>; - - // TODO: How can user info be change to something statically checked instead of a runtime check fn user_info_endpoints(&self) -> Vec<&str>; - fn device_code_endpoint(&self) -> &str; - fn auth_url_endpoint(&self) -> &str; - fn token_exchange_content_type(&self) -> &str; - fn token_exchange_endpoint(&self) -> &str; - fn token_revocation_endpoint(&self) -> Option<&str>; - fn supports_pkce(&self) -> bool; - fn provider_info(&self, public_url: &str, client_type: &ClientType) -> OAuthProviderInfo { - OAuthProviderInfo { - provider: self.name(), - client_id: self.client_id(client_type).to_string(), - auth_url_endpoint: self.auth_url_endpoint().to_string(), - device_code_endpoint: self.device_code_endpoint().to_string(), - token_endpoint: format!("{}/login/oauth/{}/device/exchange", public_url, self.name(),), - scopes: self - .scopes() - .into_iter() - .map(|s| s.to_string()) - .collect::>(), - } - } + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo>; + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo>; + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo>; - fn as_web_client(&self, config: &WebClientConfig) -> Result { - let mut client = - BasicClient::new(ClientId::new(self.client_id(&ClientType::Web).to_string())) - .set_auth_uri(AuthUrl::new(self.auth_url_endpoint().to_string())?) - .set_token_uri(TokenUrl::new(self.token_exchange_endpoint().to_string())?) - .set_revocation_url_option( - self.token_revocation_endpoint() - .map(|s| RevocationUrl::new(s.to_string())) - .transpose()?, - ) - .set_redirect_uri(RedirectUrl::new(format!( - "{}/login/oauth/{}/code/callback", - &config.prefix, - self.name() - ))?); - - if let Some(secret) = self.client_secret(&ClientType::Web) { - client = client.set_client_secret(ClientSecret::new(secret.expose_secret().to_string())) - } + fn expires_in(&self) -> Option; + fn default_scopes(&self) -> &[String]; + + /// Whether the remote OAuth provider supports PKCE (RFC 7636). Providers must + /// explicitly declare this. This controls whether v-api sends a PKCE challenge + /// to the remote provider during the authorization code exchange. Note: clients + /// calling v-api are always required to use PKCE regardless of this setting. + fn supports_pkce(&self) -> bool; - Ok(client) + fn as_web_client(&self) -> Result { + match self.authz_code_flow_info() { + Some(info) => { + let client = BasicClient::new(ClientId::new(info.remote.client_id.clone())) + .set_auth_uri(AuthUrl::new(info.remote.auth_url_endpoint.clone())?) + .set_token_uri(TokenUrl::new(info.remote.token_endpoint.clone())?) + .set_revocation_url_option( + info.remote + .revocation_endpoint + .as_ref() + .map(|url| RevocationUrl::new(url.to_string())) + .transpose()?, + ) + .set_redirect_uri(RedirectUrl::new(info.redirect_endpoint.to_string())?) + .set_client_secret(ClientSecret::new( + info.remote.client_secret.0.expose_secret().to_string(), + )); + + Ok(client) + } + None => Err(OAuthProviderError::WebClientNotSupported), + } } } @@ -155,33 +139,90 @@ where ); let response = self.client().execute(request).await?; - - tracing::trace!(status = ?response.status(), "Received response from OAuth provider"); + let status = response.status(); + + tracing::trace!(?status, "Received response from OAuth provider"); + + if !status.is_success() { + tracing::error!( + ?status, + endpoint, + "User info endpoint returned non-success status" + ); + return Err(UserInfoError::UnexpectedStatus { + endpoint: endpoint.to_string(), + status, + }); + } let bytes = response.bytes().await?; responses.push(bytes); } - self.extract_user_info(&responses) + let mut info = self.extract_user_info(&responses)?; + info.idp_token = Some(token.to_string()); + Ok(info) } } -#[derive(Debug, Deserialize, Serialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, + code: Option, + pkce: Option, + device: Option, +} + +#[derive(Clone, Debug, Serialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: String, + redirect_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + remote: OAuthProviderAuthorizationCodeRemoteInfo, +} + +#[derive(Clone, Debug, Serialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodeRemoteInfo { + client_id: String, + #[schemars(skip)] + #[serde(skip)] + client_secret: OpenApiSecretString, auth_url_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + revocation_endpoint: Option, +} + +#[derive(Clone, Debug, Serialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodePkceInfo { + client_id: TypedUuid, + redirect_endpoint: String, + proxy_port: u16, + web: OAuthProviderAuthorizationCodeInfo, +} + +#[derive(Clone, Debug, Serialize, JsonSchema)] +pub struct OAuthProviderDeviceInfo { + client_id: TypedUuid, + remote_client_id: String, + #[schemars(skip)] + #[serde(skip)] + remote_client_secret: OpenApiSecretString, device_code_endpoint: String, + token_endpoint_content_type: String, token_endpoint: String, - scopes: Vec, + revocation_endpoint: Option, } -#[derive(Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] #[serde(rename_all = "kebab-case")] pub enum OAuthProviderName { #[serde(rename = "github")] GitHub, Google, + Zendesk, } impl Display for OAuthProviderName { @@ -189,6 +230,7 @@ impl Display for OAuthProviderName { match self { OAuthProviderName::GitHub => write!(f, "github"), OAuthProviderName::Google => write!(f, "google"), + OAuthProviderName::Zendesk => write!(f, "zendesk"), } } } @@ -224,8 +266,9 @@ impl CheckOAuthClient for OAuthClient { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs similarity index 50% rename from v-api/src/endpoints/login/oauth/github.rs rename to v-api/src/endpoints/login/oauth/remote/github.rs index 2011162f..4f3405bd 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -5,26 +5,27 @@ use http::{HeaderMap, HeaderValue, header::USER_AGENT}; use hyper::body::Bytes; use reqwest::Request; -use secrecy::SecretString; use serde::Deserialize; use std::fmt; -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + }, }; +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + pub struct GitHubOAuthProvider { - // public: GitHubPublicProvider, - // private: Option, - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, + authz_code_flow_info: Option, + device_code_flow_info: Option, default_headers: HeaderMap, + default_scopes: Vec, client: reqwest::Client, } @@ -36,30 +37,45 @@ impl fmt::Debug for GitHubOAuthProvider { impl GitHubOAuthProvider { pub fn new( - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, + config: ResolvedOAuthConfig, + public_url: String, additional_scopes: Option>, ) -> Self { let mut headers = HeaderMap::new(); headers.insert(USER_AGENT, HeaderValue::from_static("v-api")); - Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, + let mut default_scopes = vec!["user:email".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/github/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/github/code/token", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://github.com/login/oauth/authorize".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), + device_code_endpoint: "https://github.com/login/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }); + + Self { + authz_code_flow_info, + device_code_flow_info, default_headers: headers, + default_scopes, client: reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build() @@ -104,6 +120,7 @@ impl ExtractUserInfo for GitHubOAuthProvider { external_id: ExternalUserId::GitHub(user.id.to_string()), verified_emails, display_name: Some(user.login), + idp_token: None, }) } } @@ -112,41 +129,12 @@ impl OAuthProvider for GitHubOAuthProvider { fn name(&self) -> OAuthProviderName { OAuthProviderName::GitHub } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["user:email"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - fn initialize_headers(&self, request: &mut Request) { *request.headers_mut() = self.default_headers.clone(); } - fn client(&self) -> &reqwest::Client { &self.client } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - fn user_info_endpoints(&self) -> Vec<&str> { vec![ "https://api.github.com/user", @@ -154,27 +142,23 @@ impl OAuthProvider for GitHubOAuthProvider { ] } - fn device_code_endpoint(&self) -> &str { - "https://github.com/login/device/code" + fn expires_in(&self) -> Option { + None } - - fn auth_url_endpoint(&self) -> &str { - "https://github.com/login/oauth/authorize" + fn default_scopes(&self) -> &[String] { + &self.default_scopes } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" + fn supports_pkce(&self) -> bool { + true } - fn token_exchange_endpoint(&self) -> &str { - "https://github.com/login/oauth/access_token" + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() } - - fn token_revocation_endpoint(&self) -> Option<&str> { + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { None } - - fn supports_pkce(&self) -> bool { - true + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() } } diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs new file mode 100644 index 00000000..92bae32d --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -0,0 +1,189 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use serde::Deserialize; +use std::fmt; + +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + }, +}; + +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + +pub struct GoogleOAuthProvider { + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, + device_code_flow_info: Option, + default_scopes: Vec, + client: reqwest::Client, +} + +impl fmt::Debug for GoogleOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GoogleOAuthProvider").finish() + } +} + +impl GoogleOAuthProvider { + pub fn new( + config: ResolvedOAuthConfig, + public_url: String, + additional_scopes: Option>, + ) -> Self { + let mut default_scopes = vec![ + "openid".to_string(), + "email".to_string(), + "profile".to_string(), + ]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/google/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/google/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/google/code/token", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), + }, + }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), + device_code_endpoint: "https://oauth2.googleapis.com/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), + }); + + Self { + authz_code_flow_info, + authz_code_pkce_flow_info, + device_code_flow_info, + default_scopes, + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct GoogleUserInfo { + sub: String, + email: String, + email_verified: bool, +} + +#[derive(Debug, Deserialize)] +struct GoogleProfile { + #[serde(default)] + names: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GoogleProfileName { + display_name: String, + metadata: GoogleProfileNameMeta, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GoogleProfileNameMeta { + #[serde(default)] + primary: bool, +} + +impl ExtractUserInfo for GoogleOAuthProvider { + // There should always be as many entries in the data list as there are endpoints. This should + // be changed in the future to be a static check + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let remote_info: GoogleUserInfo = serde_json::from_slice(&data[0])?; + let verified_emails = if remote_info.email_verified { + vec![remote_info.email] + } else { + vec![] + }; + + let profile_info: GoogleProfile = serde_json::from_slice(&data[1])?; + let display_name = profile_info + .names + .into_iter() + .filter_map(|name| name.metadata.primary.then_some(name.display_name)) + .nth(0); + + Ok(UserInfo { + external_id: ExternalUserId::Google(remote_info.sub), + verified_emails, + display_name, + idp_token: None, + }) + } +} + +impl OAuthProvider for GoogleOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Google + } + fn initialize_headers(&self, _request: &mut Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![ + "https://openidconnect.googleapis.com/v1/userinfo", + "https://people.googleapis.com/v1/people/me?personFields=names", + ] + } + + fn expires_in(&self) -> Option { + None + } + fn default_scopes(&self) -> &[String] { + &self.default_scopes + } + fn supports_pkce(&self) -> bool { + true + } + + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() + } + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + self.authz_code_pkce_flow_info.as_ref() + } + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() + } +} diff --git a/v-api/src/endpoints/login/oauth/remote/mod.rs b/v-api/src/endpoints/login/oauth/remote/mod.rs new file mode 100644 index 00000000..3a924871 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/mod.rs @@ -0,0 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +pub mod github; +pub mod google; +pub mod zendesk; diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs new file mode 100644 index 00000000..283645cd --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -0,0 +1,162 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use serde::Deserialize; +use std::fmt; + +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + }, +}; + +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + +pub struct ZendeskOAuthProvider { + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, + user_info_endpoint: String, + default_scopes: Vec, + client: reqwest::Client, +} + +impl fmt::Debug for ZendeskOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ZendeskOAuthProvider").finish() + } +} + +impl ZendeskOAuthProvider { + pub fn new( + config: ResolvedOAuthConfig, + public_url: String, + subdomain: String, + additional_scopes: Option>, + ) -> Self { + let base_url = format!("https://{}.zendesk.com", subdomain); + + let mut default_scopes = vec!["read".to_string(), "write".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + auth_url_endpoint: format!("{}/login/oauth/zendesk/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/login/oauth/zendesk/code/token", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/oauth/tokens", base_url), + revocation_endpoint: None, + }, + }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); + + Self { + authz_code_flow_info, + authz_code_pkce_flow_info, + user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + default_scopes, + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct ZendeskUserResponse { + user: ZendeskUser, +} + +#[derive(Debug, Deserialize)] +struct ZendeskUser { + id: u64, + name: String, + email: String, + verified: bool, + suspended: bool, +} + +impl ExtractUserInfo for ZendeskOAuthProvider { + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let response: ZendeskUserResponse = serde_json::from_slice(&data[0])?; + let user = response.user; + + if user.suspended { + return Err(UserInfoError::Locked); + } + + let verified_emails = if user.verified { + vec![user.email] + } else { + vec![] + }; + + Ok(UserInfo { + external_id: ExternalUserId::Zendesk(user.id.to_string()), + verified_emails, + display_name: Some(user.name), + idp_token: None, + }) + } +} + +impl OAuthProvider for ZendeskOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Zendesk + } + fn initialize_headers(&self, _request: &mut Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![&self.user_info_endpoint] + } + + fn expires_in(&self) -> Option { + // This is the maximum token duration that Zendesk supports. In the future we should make + // this configurable + Some(172800) + } + fn default_scopes(&self) -> &[String] { + &self.default_scopes + } + fn supports_pkce(&self) -> bool { + true + } + + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() + } + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + self.authz_code_pkce_flow_info.as_ref() + } + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + None + } +} diff --git a/v-api/src/secrets.rs b/v-api/src/secrets.rs index c5e8042e..8630e23d 100644 --- a/v-api/src/secrets.rs +++ b/v-api/src/secrets.rs @@ -10,7 +10,7 @@ use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize, Serializer}; use std::borrow::Cow; -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OpenApiSecretString(pub SecretString); impl From for OpenApiSecretString { diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml new file mode 100644 index 00000000..4054def7 --- /dev/null +++ b/v-cli-sdk/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "v-cli-sdk" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[dependencies] +anyhow = { workspace = true } +clap = { workspace = true } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["server", "http1"] } +hyper-util = { workspace = true, features = ["tokio"] } +oauth2 = { workspace = true } +oauth2-reqwest = { workspace = true } +owo-colors = { workspace = true } +progenitor-client = { workspace = true } +reqwest = { workspace = true } +schemars = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tabwriter = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } +uuid = { workspace = true } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs new file mode 100644 index 00000000..e641062f --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -0,0 +1,283 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand, ValueEnum}; +use oauth2::TokenResponse; +use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; + +use crate::{ + VCliConfig, VCliContext, + cmd::auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, +}; + +pub trait CliAdapterToken { + fn access_token(&self) -> &str; + fn idp_token(&self) -> Option<&str>; +} + +pub trait CliConsumerLoginProvider: Into + Subcommand + Debug + Clone {} +impl CliConsumerLoginProvider for T where T: Into + Subcommand + Debug + Clone {} + +// Authenticates and generates an access token for interacting with the api +#[derive(Parser, Debug, Clone)] +#[clap(name = "login")] +pub struct Login +where + SupportedProviders: CliConsumerLoginProvider, +{ + #[command(subcommand)] + method: LoginMethod, + #[arg(short = 'm', default_value = "id")] + mode: AuthenticationMode, +} + +impl

Login

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + let (access_token, idp_token) = self.method.run(ctx, self.mode).await?; + + ctx.config_mut().set_token(access_token); + ctx.config_mut().save()?; + + // If we are acquiring an IdP token, present it to the user. + if let Some(idp_token) = idp_token { + println!( + "\nYou can now additionally authenticate against the requested remote service API \ + with the following token." + ); + println!("IdP token: {}", idp_token); + println!(); + println!( + "Please note that this should be kept secure as calls made with this token are \ + made on behalf of your user acount" + ); + } + + Ok(()) + } +} + +#[derive(Subcommand, Debug, Clone)] +pub enum LoginMethod +where + SupportedProviders: Subcommand + Debug + Clone, +{ + #[command(name = "oauth")] + /// Login via OAuth + OAuth { + #[command(subcommand)] + provider: SupportedProviders, + /// Additionally retrieve a the underlying IdP token. This token is not stored. Remote mode + /// should be used when you need to authenticate to the underlying system frontend by the API + #[arg(long, default_value = "false")] + request_idp_token: bool, + }, + /// Login via Magic Link + #[command(name = "mlink")] + MagicLink { + /// Email recipient to login via + email: String, + /// Optional access scopes to apply to this session + scope: Option, + }, +} + +#[derive(Copy, Clone)] +pub enum LoginProvider { + Google, + GitHub, + Zendesk, +} + +#[derive(Copy, ValueEnum, Debug, Clone, PartialEq)] +pub enum AuthenticationMode { + /// Retrieve and store an identity token. Identity mode is the default and should be used to + /// when you do not require extended (multi-day) access + #[value(name = "id")] + Identity, + /// Retrieve and store an api token. Token mode should be used when you want to authenticate + /// a machine for continued access. This requires the permission to create api tokens + #[value(name = "token")] + Token, +} + +impl LoginMethod +where + SupportedProviders: CliConsumerLoginProvider, +{ + pub async fn run( + &self, + ctx: &T, + mode: AuthenticationMode, + ) -> Result<(String, Option)> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + match self { + Self::OAuth { + provider, + request_idp_token, + } => { + let adapter = ctx.oauth_adapter(); + let provider = provider.clone().into(); + let provider = adapter.provider(provider).await?; + + // We now need to inspect the provider to determine the correct flow to use. If + // possible we use a limited input device flow, but not all providers support it. + // To handle those cases we need to use a proxy path that emulates an authorization + // code flow. + if provider.device_code_endpoint().is_some() { + if *request_idp_token { + anyhow::bail!( + "Remote token access is not supported via device authentication flow" + ); + } + Ok(( + self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await?, + None, + )) + } else if provider.supports_pkce_only() { + self.run_oauth_code_provider( + provider, + mode, + *request_idp_token, + ctx.oauth_adapter(), + ) + .await + } else { + anyhow::bail!("OAuth provider does not support any CLI authentication methods") + } + } + Self::MagicLink { email, scope } => Ok(( + self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await?, + None, + )), + } + } + + async fn run_oauth_device_provider( + &self, + provider: V, + mode: AuthenticationMode, + adapter: T, + ) -> Result + where + T: CliOAuthAdapter, + V: CliOAuthProviderInfo, + { + let oauth_client = oauth::device::DeviceOAuth::new(provider)?; + let details = oauth_client.get_device_authorization().await?; + + println!( + "To complete login visit: {} and enter {}", + details.verification_uri().as_str(), + details.user_code().secret() + ); + + let token_response = oauth_client.login(&details).await; + + let identity_token = match token_response { + Ok(token) => Ok(token.access_token().to_owned()), + Err(err) => Err(anyhow::anyhow!("Authentication failed: {}", err)), + }?; + + match mode { + AuthenticationMode::Identity => Ok(identity_token.secret().to_string()), + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.secret()) + .await?; + Ok(token.access_token().to_string()) + } + } + } + + async fn run_oauth_code_provider( + &self, + provider: T, + mode: AuthenticationMode, + request_idp_token: bool, + adapter: V, + ) -> Result<(String, Option)> + where + T: CliOAuthProviderInfo, + V: CliOAuthAdapter + Send + Sync + 'static, + { + let oauth_client = oauth::code::CodeOAuth::new(provider)?; + let adapter = Arc::new(adapter); + + let identity_token = oauth_client + .login(Arc::clone(&adapter), request_idp_token) + .await?; + + let access_token = match mode { + AuthenticationMode::Identity => identity_token.access_token().to_string(), + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.access_token()) + .await?; + token.access_token().to_string() + } + }; + + let idp_token = if request_idp_token { + identity_token.idp_token().map(|s| s.to_string()) + } else { + None + }; + + Ok((access_token, idp_token)) + } + + async fn run_magic_link( + &self, + email: &str, + scope: Option<&str>, + adapter: T, + ) -> Result + where + T: CliMagicLinkAdapter, + { + let attempt = adapter.create_attempt(email, scope).await?; + + let mut auth_secret = String::new(); + print!("Enter the login token sent to the recipient: "); + std::io::stdout().flush()?; + std::io::stdin().read_line(&mut auth_secret)?; + + let token = adapter.exchange(attempt, email, &auth_secret).await?; + + Ok(token.access_token().to_string()) + } +} + +pub trait CliMagicLinkAdapter { + type Attempt; + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + #[allow(clippy::type_complexity)] + fn create_attempt( + &self, + email: &str, + scope: Option<&str>, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn exchange( + &self, + attempt: Self::Attempt, + email: &str, + token: &str, + ) -> Pin> + Send>>; +} diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs new file mode 100644 index 00000000..d4b8bcaa --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -0,0 +1,48 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; +use std::error::Error as StdError; + +use crate::{VCliContext, cmd::auth::login::CliConsumerLoginProvider}; + +pub mod login; +pub mod oauth; +pub mod proxy; + +// Authenticate against the Meetings API +#[derive(Parser, Debug)] +#[clap(name = "auth")] +pub struct Auth

+where + P: CliConsumerLoginProvider, +{ + #[command(subcommand)] + auth: AuthCommands

, +} + +#[derive(Subcommand, Debug, Clone)] +enum AuthCommands

+where + P: CliConsumerLoginProvider, +{ + /// Login via an authentication provider + Login(login::Login

), +} + +impl

Auth

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, + { + match &self.auth { + AuthCommands::Login(login) => login.run(ctx).await, + } + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs new file mode 100644 index 00000000..92db7b45 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -0,0 +1,236 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, +}; + +use anyhow::Result; +use http::{Request, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; + +use oauth2::{ + AuthType, AuthUrl, ClientId, CsrfToken, EndpointNotSet, EndpointSet, PkceCodeChallenge, + RedirectUrl, Scope, TokenUrl, basic::BasicClient, +}; +use tokio::sync::oneshot; +use uuid::Uuid; + +use crate::cmd::auth::{ + login::LoginProvider, + oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, + proxy::run_proxy_server, +}; + +type CodeClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointNotSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct CodeOAuth { + provider: LoginProvider, + client: CodeClient, + client_id: Uuid, + redirect_uri: String, + scopes: Vec, + port: u16, +} + +impl CodeOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .auth_url_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define an authorization url") + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_redirect_uri(RedirectUrl::new( + provider + .redirect_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a redirect url") + })? + .to_string(), + )?); + + Ok(Self { + provider: provider.provider(), + client, + client_id: provider.client_id(), + redirect_uri: provider.redirect_endpoint().unwrap_or_default().to_string(), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + port: provider.public_pkce_port().ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a public proxy port") + })?, + }) + } + + /// Build the authorization URL that the user should visit in a browser. + /// Returns the full URL and the CSRF state token used for verification. + pub fn authorize_url( + &self, + pkce_challenge: PkceCodeChallenge, + ) -> (oauth2::url::Url, CsrfToken) { + let mut req = self + .client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + req.url() + } + + /// Run the full authorization code login flow: + /// + /// 1. Generate the authorization URL and print it for the user. + /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. + /// 3. Forward the redirect request to the API server via the adapter. + /// 4. Extract the token from the server's response. + /// 5. Return a success page to the browser and shut down the proxy. + pub async fn login(&self, adapter: Arc, request_idp_token: bool) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, _csrf_state) = self.authorize_url(pkce_challenge); + + println!( + "Open the following URL in your browser to authenticate:\n\n {}\n", + auth_url + ); + + // Channel to receive the token extracted from the server response. + let (token_tx, token_rx) = oneshot::channel::>(); + #[allow(clippy::type_complexity)] + let token_tx: Arc>>>> = + Arc::new(Mutex::new(Some(token_tx))); + + // Channel to shut down the proxy server once we have the token. + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let port = self.port; + + // Spawn the local proxy server in a background task. + tokio::spawn({ + let callback_token_tx = Arc::clone(&token_tx); + let error_token_tx = Arc::clone(&token_tx); + let client_id = self.client_id; + let redirect_uri = self.redirect_uri.clone(); + let provider = self.provider; + + async move { + let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some( + Box::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + let code = request + .uri() + .query() + .and_then(|q: &str| { + q.split('&') + .filter_map(|pair: &str| pair.split_once('=')) + .find(|(key, _): &(&str, &str)| *key == "code") + .map(|(_, value): (&str, &str)| value.to_string()) + }) + .ok_or_else(|| { + anyhow::anyhow!( + "Missing 'code' query parameter in callback request" + ) + })?; + + // Forward the redirect request to the API server. + let token = adapter + .exchange_authorization_code(super::AuthorizationCodeExchange { + provider, + client_id, + redirect_uri: redirect_uri.clone(), + grant_type: "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + }) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(Ok(token)); + } + + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/html; charset=utf-8") + .body(Full::new(Bytes::from(concat!( + "", + "", + "

Authentication successful. This window should close automatically.

", + "" + ))))?) + }) + as Pin< + Box< + dyn Future>>> + + Send, + >, + > + }), + ))); + + if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { + eprintln!("Proxy server error: {e}"); + + // If the proxy died before we got a token, unblock the + // receiver so the caller is not stuck forever. + if let Ok(mut guard) = error_token_tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); + } + } + } + }); + + // Wait for the proxy callback to extract the token. + let token = token_rx.await.map_err(|_| { + anyhow::anyhow!( + "Authentication callback was never received — proxy server may have exited early" + ) + })??; + + // Tell the proxy server to stop. + let _ = shutdown_tx.send(()); + + Ok(token) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/device.rs b/v-cli-sdk/src/cmd/auth/oauth/device.rs new file mode 100644 index 00000000..b1956dbb --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/device.rs @@ -0,0 +1,96 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use oauth2::{ + AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, + EndpointSet, Scope, StandardDeviceAuthorizationResponse, StandardTokenResponse, TokenUrl, + basic::{BasicClient, BasicTokenType}, +}; + +use crate::cmd::auth::oauth::CliOAuthProviderInfo; + +type DeviceClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct DeviceOAuth { + client: DeviceClient, + http: oauth2_reqwest::ReqwestClient, + scopes: Vec, +} + +impl DeviceOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + if let Some(device_endpoint) = provider.device_code_endpoint() { + let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .device_code_endpoint() + .ok_or_else(|| { + anyhow::anyhow!( + "OAuth device flow provider must define an device code url" + ) + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_device_authorization_url(device_auth_url); + + Ok(Self { + client, + http: oauth2_reqwest::ReqwestClient::from( + reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(), + ), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + }) + } else { + anyhow::bail!("Device authorization is not supported by this provider") + } + } + + pub async fn login( + &self, + details: &StandardDeviceAuthorizationResponse, + ) -> Result> { + let token = self + .client + .exchange_device_access_token(details) + .set_max_backoff_interval(details.interval()) + .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) + .await; + + Ok(token?) + } + + pub async fn get_device_authorization(&self) -> Result { + let mut req = self.client.exchange_device_code(); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + let res = req.request_async(&self.http).await; + + Ok(res?) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs new file mode 100644 index 00000000..4b0a0169 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -0,0 +1,59 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use oauth2::PkceCodeVerifier; +use std::{error::Error as StdError, future::Future, pin::Pin}; +use uuid::Uuid; + +pub mod code; +pub mod device; + +use crate::cmd::auth::login::{CliAdapterToken, LoginProvider}; + +/// Parameters for exchanging an authorization code for an access token. +pub struct AuthorizationCodeExchange { + pub provider: super::login::LoginProvider, + pub client_id: Uuid, + pub redirect_uri: String, + pub grant_type: String, + pub code: String, + pub pkce_verifier: PkceCodeVerifier, + pub request_idp_token: bool, +} + +pub trait CliOAuthAdapter { + type ShortToken: CliAdapterToken + Send + 'static; + type LongToken: CliAdapterToken + Send + 'static; + type Error: StdError + Send + Sync + 'static; + + #[allow(clippy::type_complexity)] + fn provider( + &self, + provider: super::login::LoginProvider, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn exchange_authorization_code( + &self, + exchange: AuthorizationCodeExchange, + ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] + fn get_long_lived_token( + &self, + access_token: &str, + ) -> Pin> + Send>>; +} + +pub trait CliOAuthProviderInfo { + fn provider(&self) -> LoginProvider; + fn client_id(&self) -> Uuid; + fn remote_client_id(&self) -> &str; + fn public_pkce_port(&self) -> Option; + fn supports_pkce_only(&self) -> bool; + fn device_code_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> Option<&str>; + fn token_endpoint(&self) -> &str; + fn redirect_endpoint(&self) -> Option<&str>; + fn scopes(&self) -> &[String]; +} diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs new file mode 100644 index 00000000..527584f7 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -0,0 +1,157 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; +use hyper::service::service_fn; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +/// A callback function that receives an incoming HTTP request and returns a response. +/// Wrapped in `Arc>>` so it can be called at most once. +pub type CallbackFn = Box< + dyn FnOnce( + Request, + ) -> std::pin::Pin< + Box>>> + Send>, + > + Send, +>; + +/// A shareable, single-use callback. The first request to arrive `.take()`s +/// the inner function; any subsequent request receives an error response. +pub type Callback = Arc>>; + +/// Start a minimal HTTP server on the given port that forwards every incoming +/// request to `callback` and returns whatever response the callback produces. +/// +/// The server will run until a message is sent on the `shutdown` channel, at +/// which point it will stop accepting new connections and return. +pub async fn run_proxy_server( + port: u16, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + serve_loop(listener, callback, shutdown).await +} + +/// Core accept-loop shared by [`run_proxy_server`] and tests. +/// +/// Accepts connections on `listener`, forwarding each request to `callback`. +/// Stops when `shutdown` fires. +async fn serve_loop( + listener: TcpListener, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + tokio::pin!(shutdown); + + loop { + tokio::select! { + _ = &mut shutdown => { + break; + } + accepted = listener.accept() => { + let (stream, _remote_addr) = accepted?; + let io = TokioIo::new(stream); + let cb = Arc::clone(&callback); + + tokio::task::spawn(async move { + let service = service_fn(move |req: Request| { + let cb = Arc::clone(&cb); + async move { + let handler = cb + .lock() + .expect("callback mutex poisoned") + .take(); + + match handler { + Some(f) => f(req).await, + None => Ok(Response::builder() + .status(StatusCode::GONE) + .body(Full::new(Bytes::from( + "Callback has already been invoked", + ))) + .expect("building static response cannot fail")), + } + } + }); + + if let Err(err) = + hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {err}"); + } + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::StatusCode; + + #[tokio::test] + async fn test_proxy_server_responds() { + let callback: Callback = Arc::new(Mutex::new(Some(Box::new(|_req| { + Box::pin(async { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("hello from callback"))) + .unwrap()) + }) + })))); + + let (tx, rx) = oneshot::channel::<()>(); + + // Use port 0 to let the OS pick an available port. + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn({ + let callback = Arc::clone(&callback); + async move { + serve_loop(listener, callback, rx).await.unwrap(); + } + }); + + // Send a request to the server. + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "hello from callback"); + + // A second request should get a 410 GONE. + let resp2 = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp2.status(), 410); + + // Shut down the server. + tx.send(()).unwrap(); + server_handle.await.unwrap(); + } +} diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs new file mode 100644 index 00000000..1ba94478 --- /dev/null +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -0,0 +1,143 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; + +use crate::{FormatStyle, VCliContext}; + +pub trait VCliConfig { + fn host(&self) -> Option<&str>; + fn set_host(&mut self, host: String); + fn token(&self) -> Option<&str>; + fn set_token(&mut self, token: String); + fn default_format(&self) -> FormatStyle; + fn set_default_format(&mut self, format: FormatStyle); + fn mlink_redirect(&self) -> Option<&str>; + fn set_mlink_redirect(&mut self, redirect: String); + fn mlink_secret(&self) -> Option<&str>; + fn set_mlink_secret(&mut self, secret: String); + fn save(&self) -> Result<(), std::io::Error>; +} + +#[derive(Debug, Parser)] +#[clap(name = "config")] +pub struct ConfigCmd { + #[clap(subcommand)] + setting: SettingCmd, +} + +#[derive(Debug, Subcommand)] +pub enum SettingCmd { + /// Gets a setting + #[clap(subcommand, name = "get")] + Get(GetCmd), + /// Sets a setting + #[clap(subcommand, name = "set")] + Set(SetCmd), +} + +#[derive(Debug, Subcommand)] +pub enum GetCmd { + /// Get the default formatter to use when printing results + #[clap(name = "format")] + Format, + /// Get the configured API host in use + #[clap(name = "host")] + Host, + /// Get the configured access token + #[clap(name = "token")] + Token, + /// Get the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri, + /// Get the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret, +} + +#[derive(Debug, Subcommand)] +pub enum SetCmd { + /// Set the default formatter to use when printing results + #[clap(name = "format")] + Format { format: FormatStyle }, + /// Set the configured API host to use + #[clap(name = "host")] + Host { host: String }, + /// Set the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri { redirect: String }, + /// Set the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret { secret: String }, +} + +impl ConfigCmd { + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: VCliContext, + { + match &self.setting { + SettingCmd::Get(get) => get.run(ctx.config()).await?, + SettingCmd::Set(set) => set.run(ctx.config_mut()).await?, + } + + Ok(()) + } +} + +impl GetCmd { + pub async fn run(&self, config: &T) -> Result<()> + where + T: VCliConfig, + { + match &self { + GetCmd::Format => { + println!("{}", config.default_format()); + } + GetCmd::Host => { + println!("{}", config.host().unwrap_or("None")); + } + GetCmd::Token => { + println!("{}", config.token().unwrap_or("None")); + } + GetCmd::MagicLinkRedirectUri => { + println!("{}", config.mlink_redirect().unwrap_or("None")); + } + GetCmd::MagicLinkSecret => { + println!("{}", config.mlink_secret().unwrap_or("None")); + } + } + + Ok(()) + } +} + +impl SetCmd { + pub async fn run(&self, config: &mut T) -> Result<()> + where + T: VCliConfig, + { + match &self { + SetCmd::Format { format } => { + config.set_default_format(*format); + config.save()?; + } + SetCmd::Host { host } => { + config.set_host(host.to_string()); + config.save()?; + } + SetCmd::MagicLinkRedirectUri { redirect } => { + config.set_mlink_redirect(redirect.to_string()); + config.save()?; + } + SetCmd::MagicLinkSecret { secret } => { + config.set_mlink_secret(secret.to_string()); + config.save()?; + } + } + + Ok(()) + } +} diff --git a/v-cli-sdk/src/cmd/mod.rs b/v-cli-sdk/src/cmd/mod.rs new file mode 100644 index 00000000..df67c4a3 --- /dev/null +++ b/v-cli-sdk/src/cmd/mod.rs @@ -0,0 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +pub mod auth; +pub mod config; diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs new file mode 100644 index 00000000..0d54b489 --- /dev/null +++ b/v-cli-sdk/src/err.rs @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::{Error, anyhow}; +use progenitor_client::Error as ProgenitorClientError; + +use crate::{VApiErrorMessage, VCliContext, VerbosityLevel}; + +pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error +where + T: VCliContext, + E: VApiErrorMessage, +{ + let mut err = anyhow!("API Request failed"); + + match client_err { + ProgenitorClientError::CommunicationError(inner) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context("Communication Error").context(inner); + } + } + ProgenitorClientError::ErrorResponse(response) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Status: {}", response.status())); + err = err.context(format!("Headers {:?}", response.headers())); + } + + let response_message = response.into_inner(); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!( + "Request {}", + response_message.request_id().unwrap_or("") + )); + } + + err = err.context(format!( + "Code: {}", + response_message.error_code().unwrap_or("") + )); + err = err.context(response_message.message().unwrap_or("").to_string()); + } + ProgenitorClientError::InvalidRequest(message) => { + err = err.context("Invalid request").context(message); + } + ProgenitorClientError::InvalidResponsePayload(_, inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::UnexpectedResponse(response) => { + err = err + .context("Unexpected response") + .context(format!("Status: {}", response.status())); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Headers {:?}", response.headers())); + } + } + ProgenitorClientError::ResponseBodyError(inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::InvalidUpgrade(inner) => { + err = err.context("Invalid upgrade").context(inner) + } + ProgenitorClientError::Custom(inner) => { + err = err.context("Inner progenitor error").context(inner) + } + } + + err +} diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs new file mode 100644 index 00000000..bd398174 --- /dev/null +++ b/v-cli-sdk/src/lib.rs @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +pub mod cmd; +pub mod err; +pub mod printer; + +use crate::cmd::auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}; +pub use cmd::config::VCliConfig; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum VerbosityLevel { + None, + All, +} + +#[derive( + Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize, Default, +)] +pub enum FormatStyle { + #[default] + #[value(name = "json")] + Json, + #[value(name = "tab")] + Tab, +} + +impl Display for FormatStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Json => write!(f, "json"), + Self::Tab => write!(f, "tab"), + } + } +} + +pub trait VCliContext { + type ShortToken; + type LongToken; + type Error; + + fn config(&self) -> &impl VCliConfig; + fn config_mut(&mut self) -> &mut impl VCliConfig; + fn client(&self) -> Option; + fn printer(&self) -> Option<&P>; + fn verbosity(&self) -> VerbosityLevel; + + fn oauth_adapter( + &self, + ) -> impl CliOAuthAdapter< + ShortToken = Self::ShortToken, + LongToken = Self::LongToken, + Error = Self::Error, + > + Send + + Sync + + 'static; + fn mlink_adapter( + &self, + ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; +} + +pub trait VApiErrorMessage { + fn message(&self) -> Option<&str>; + fn error_code(&self) -> Option<&str>; + fn request_id(&self) -> Option<&str>; +} diff --git a/v-cli-sdk/src/printer/mod.rs b/v-cli-sdk/src/printer/mod.rs new file mode 100644 index 00000000..3fc117e1 --- /dev/null +++ b/v-cli-sdk/src/printer/mod.rs @@ -0,0 +1,227 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use owo_colors::{OwoColorize, Style}; +use serde::Serialize; +use std::io::Write; +use tabwriter::TabWriter; + +#[derive(Debug, Clone)] +pub enum Printer { + Json, + Tab, +} + +pub trait CliOutput { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug; +} + +impl Printer { + /// Print any serializable response object in the configured format. + /// + /// - `Json` mode emits compact, single-line JSON. + /// - `Tab` mode serializes to a `serde_json::Value` and pretty-prints it + /// with tab-aligned key/value pairs. + pub fn print_response(&self, value: &T) + where + T: Serialize, + { + let json_value = serde_json::to_value(value) + .unwrap_or_else(|e| serde_json::Value::String(format!("", e))); + + match self { + Printer::Json => { + println!("{}", serde_json::to_string(&json_value).unwrap_or_default()); + } + Printer::Tab => { + let styles = TabStyles::default(); + let mut tw = TabWriter::new(vec![]).ansi(true); + pretty_print_value(&mut tw, &json_value, 0, &styles); + tw.flush().unwrap(); + let output = String::from_utf8(tw.into_inner().unwrap()).unwrap(); + print!("{}", output); + } + } + } + + /// Print an error from a progenitor client response. + /// + /// A 401 Unauthorized is treated specially: instead of dumping the raw + /// server error we print a short, actionable message telling the user to + /// authenticate first. + pub fn print_error_response(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + // Check for 401 Unauthorized up-front, regardless of output format. + if let Some(status) = value.status() + && status == reqwest::StatusCode::UNAUTHORIZED + { + eprintln!("Authentication required. Please run `auth login` first."); + return; + } + + match self { + Printer::Json => { + // For JSON mode, try to extract a serializable body from the + // error and fall back to the Debug representation. + let msg = match value { + progenitor_client::Error::ErrorResponse(rv) => { + serde_json::to_string(rv.as_ref()).ok() + } + _ => None, + }; + eprintln!("{}", msg.unwrap_or_else(|| format!("{:?}", value))); + } + Printer::Tab => { + eprintln!("{}", value); + } + } + } +} + +impl CliOutput for Printer { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + self.print_error_response(value); + } +} + +// --------------------------------------------------------------------------- +// Tab-indented pretty-printer for serde_json::Value +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +struct TabStyles { + label: Style, + value: Style, + null: Style, +} + +impl Default for TabStyles { + fn default() -> Self { + TabStyles { + label: Style::new().bold(), + value: Style::new(), + null: Style::new().dimmed(), + } + } +} + +fn indent(tw: &mut TabWriter>, depth: usize) { + for _ in 0..depth { + let _ = write!(tw, "\t"); + } +} + +fn pretty_print_value( + tw: &mut TabWriter>, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(map) => { + for (key, val) in map { + pretty_print_field(tw, key, val, depth, styles); + } + } + serde_json::Value::Array(arr) => { + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 1, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!(tw, "{}", format_scalar(value, styles)); + } + } +} + +fn pretty_print_field( + tw: &mut TabWriter>, + key: &str, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(_) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + pretty_print_value(tw, value, depth + 1, styles); + } + serde_json::Value::Array(arr) if arr.is_empty() => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + "[]".style(styles.null), + ); + } + serde_json::Value::Array(arr) if arr.iter().all(is_scalar) => { + // Print simple arrays inline, one value per line with the key on + // the first line only (mimics the existing TabDisplay list style). + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + if i == 0 { + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(val, styles), + ); + } else { + let _ = writeln!(tw, "\t{}", format_scalar(val, styles)); + } + } + } + serde_json::Value::Array(arr) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + for (i, val) in arr.iter().enumerate() { + indent(tw, depth + 1); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 2, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(value, styles), + ); + } + } +} + +fn is_scalar(value: &serde_json::Value) -> bool { + matches!( + value, + serde_json::Value::Null + | serde_json::Value::Bool(_) + | serde_json::Value::Number(_) + | serde_json::Value::String(_) + ) +} + +fn format_scalar(value: &serde_json::Value, styles: &TabStyles) -> String { + match value { + serde_json::Value::Null => format!("{}", "null".style(styles.null)), + serde_json::Value::Bool(b) => format!("{}", b.style(styles.value)), + serde_json::Value::Number(n) => format!("{}", n.style(styles.value)), + serde_json::Value::String(s) => format!("{}", s.style(styles.value)), + // Fallback for non-scalars that end up here + other => format!("{}", other.style(styles.value)), + } +} diff --git a/v-model/Cargo.toml b/v-model/Cargo.toml index b4f31561..ab5d5289 100644 --- a/v-model/Cargo.toml +++ b/v-model/Cargo.toml @@ -25,6 +25,7 @@ serde_json = { workspace = true } steno = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } +url = { workspace = true } uuid = { workspace = true, features = ["v4", "serde"] } [dev-dependencies] diff --git a/v-model/src/lib.rs b/v-model/src/lib.rs index b02c4efa..ac7b2f7e 100644 --- a/v-model/src/lib.rs +++ b/v-model/src/lib.rs @@ -15,11 +15,9 @@ use schema_ext::MagicLinkAttemptState; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{ - collections::{BTreeMap, BTreeSet}, - fmt::Display, -}; +use std::{collections::BTreeSet, fmt::Display}; use thiserror::Error; +use url::Url; pub mod db; pub mod permissions; @@ -282,26 +280,22 @@ pub struct LoginAttempt { } impl LoginAttempt { - pub fn callback_url(&self) -> String { - let mut params = BTreeMap::new(); - - if let Some(state) = &self.state { - params.insert("state", state); - } - - if let Some(error) = &self.error { - params.insert("error", error); - } else if let Some(authz_code) = &self.authz_code { - params.insert("code", authz_code); + pub fn callback_url(&self) -> Result { + let mut url = Url::parse(&self.redirect_uri)?; + + { + let mut pairs = url.query_pairs_mut(); + if let Some(state) = &self.state { + pairs.append_pair("state", state); + } + if let Some(error) = &self.error { + pairs.append_pair("error", error); + } else if let Some(authz_code) = &self.authz_code { + pairs.append_pair("code", authz_code); + } } - let query_string = params - .into_iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("&"); - - [self.redirect_uri.as_str(), query_string.as_str()].join("?") + Ok(url.to_string()) } } @@ -324,6 +318,13 @@ impl NewLoginAttempt { redirect_uri: String, scope: String, ) -> Result { + // Validate that the redirect URI is a well-formed URL. This ensures + // callback_url() can always parse it later. + Url::parse(&redirect_uri).map_err(|err| InvalidValueError { + field: "redirect_uri".to_string(), + error: format!("Invalid URL: {}", err), + })?; + Ok(Self { id: TypedUuid::new_v4(), attempt_state: LoginAttemptState::New, diff --git a/v-model/src/schema_ext.rs b/v-model/src/schema_ext.rs index 8ac7126d..769280c3 100644 --- a/v-model/src/schema_ext.rs +++ b/v-model/src/schema_ext.rs @@ -51,7 +51,16 @@ macro_rules! sql_conversion { } #[derive( - Debug, PartialEq, Clone, FromSqlRow, AsExpression, Serialize, Deserialize, JsonSchema, Default, + Copy, + Debug, + PartialEq, + Clone, + FromSqlRow, + AsExpression, + Serialize, + Deserialize, + JsonSchema, + Default, )] #[diesel(sql_type = AttemptState)] #[serde(rename_all = "lowercase")] diff --git a/v-model/src/storage/mod.rs b/v-model/src/storage/mod.rs index 8102f3be..ae7b5d83 100644 --- a/v-model/src/storage/mod.rs +++ b/v-model/src/storage/mod.rs @@ -257,6 +257,7 @@ pub struct LoginAttemptFilter { pub client_id: Option>>, pub attempt_state: Option>, pub authz_code: Option>, + pub provider: Option>, } #[cfg_attr(feature = "mock", automock)] @@ -270,6 +271,15 @@ pub trait LoginAttemptStore { pagination: &ListPagination, ) -> Result, StoreError>; async fn upsert(&self, attempt: NewLoginAttempt) -> Result; + /// Atomically update a login attempt, but only if it is currently in the `expected_state`. + /// The `attempt` must have the desired target state and any other field updates already set. + /// Returns `StoreError::InvariantFailed` if the attempt is not in the expected state, + /// which prevents TOCTOU races on state transitions. + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result; } #[derive(Debug, Default)] diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index d645bc55..10224d0e 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -39,7 +39,7 @@ use crate::{ magic_link_client_redirect_uri, magic_link_client_secret, mapper, oauth_client, oauth_client_redirect_uri, oauth_client_secret, }, - schema_ext::MagicLinkAttemptState, + schema_ext::{LoginAttemptState, MagicLinkAttemptState}, storage::{LinkRequestFilter, LinkRequestStore, StoreError}, }; @@ -705,6 +705,7 @@ impl LoginAttemptStore for PostgresStore { client_id, attempt_state, authz_code, + provider, } = filter; if let Some(id) = id { @@ -731,6 +732,10 @@ impl LoginAttemptStore for PostgresStore { query = query.filter(login_attempt::authz_code.eq_any(authz_code)); } + if let Some(provider) = provider { + query = query.filter(login_attempt::provider.eq_any(provider)); + } + let results = query .offset(pagination.offset) .limit(pagination.limit) @@ -775,6 +780,38 @@ impl LoginAttemptStore for PostgresStore { Ok(attempt_m.into()) } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result { + let conn = self.pool.get().await?; + let result: Option = update( + login_attempt::dsl::login_attempt + .filter(login_attempt::id.eq(attempt.id.into_untyped_uuid())) + .filter(login_attempt::attempt_state.eq(expected_state)), + ) + .set(( + login_attempt::attempt_state.eq(attempt.attempt_state), + login_attempt::authz_code.eq(attempt.authz_code), + login_attempt::expires_at.eq(attempt.expires_at), + login_attempt::error.eq(attempt.error), + login_attempt::provider_authz_code.eq(attempt.provider_authz_code), + login_attempt::provider_error.eq(attempt.provider_error), + )) + .get_result_async::(&*conn) + .await + .optional()?; + + match result { + Some(attempt) => Ok(LoginAttempt::from(attempt)), + None => Err(StoreError::InvariantFailed(format!( + "Login attempt {} is not in expected state for transition to {}", + attempt.id, attempt.attempt_state, + ))), + } + } } #[async_trait]