From a71b6024a3d6e49565a522e1351d5a45d6dec9dd Mon Sep 17 00:00:00 2001 From: augustuswm Date: Tue, 28 Apr 2026 15:20:38 -0500 Subject: [PATCH 01/51] Add Zendesk as an OAuth backend --- v-api/src/config.rs | 1 + v-api/src/endpoints/login/mod.rs | 10 + .../src/endpoints/login/oauth/device_token.rs | 18 ++ v-api/src/endpoints/login/oauth/github.rs | 4 +- v-api/src/endpoints/login/oauth/google.rs | 4 +- v-api/src/endpoints/login/oauth/mod.rs | 9 +- v-api/src/endpoints/login/oauth/zendesk.rs | 173 ++++++++++++++++++ 7 files changed, 212 insertions(+), 7 deletions(-) create mode 100644 v-api/src/endpoints/login/oauth/zendesk.rs diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 22789099..917063ce 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -151,6 +151,7 @@ pub struct SendGridConfig { pub struct OAuthProviders { pub github: Option, pub google: Option, + pub zendesk: Option, } #[derive(Debug, Deserialize)] diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index f9c471e4..d6c30b30 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -60,6 +60,7 @@ impl From for HttpError { pub enum ExternalUserId { GitHub(String), Google(String), + Zendesk(String), #[cfg(feature = "local-dev")] Local(String), MagicLink(String), @@ -70,6 +71,7 @@ impl ExternalUserId { match self { Self::GitHub(id) => id, Self::Google(id) => id, + Self::Zendesk(id) => id, #[cfg(feature = "local-dev")] Self::Local(id) => id, Self::MagicLink(id) => id, @@ -80,6 +82,7 @@ impl ExternalUserId { match self { Self::GitHub(_) => "github", Self::Google(_) => "google", + Self::Zendesk(_) => "zendesk", #[cfg(feature = "local-dev")] Self::Local(_) => "local", Self::MagicLink(_) => "magic-link", @@ -103,6 +106,7 @@ impl Serialize for ExternalUserId { match self { ExternalUserId::GitHub(id) => serializer.serialize_str(&format!("github-{}", id)), ExternalUserId::Google(id) => serializer.serialize_str(&format!("google-{}", id)), + ExternalUserId::Zendesk(id) => serializer.serialize_str(&format!("zendesk-{}", id)), #[cfg(feature = "local-dev")] ExternalUserId::Local(id) => serializer.serialize_str(&format!("local-{}", id)), ExternalUserId::MagicLink(id) => { @@ -142,6 +146,12 @@ impl<'de> Deserialize<'de> for ExternalUserId { } else { Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) } + } else if let Some(("", id)) = value.split_once("zendesk-") { + if !id.is_empty() { + Ok(ExternalUserId::Zendesk(id.to_string())) + } else { + Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) + } } else if let Some(("", id)) = value.split_once("local-") { #[cfg(feature = "local-dev")] { diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs index 0c9d5c11..5fb2eb04 100644 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ b/v-api/src/endpoints/login/oauth/device_token.rs @@ -124,6 +124,24 @@ where tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); + if provider.device_code_endpoint().is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some(format!( + "{} does not support device code flow", + path.provider + )), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + let exchange_request = body.into_inner(); if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/github.rs index bce13561..614f7d36 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/github.rs @@ -154,8 +154,8 @@ impl OAuthProvider for GitHubOAuthProvider { ] } - fn device_code_endpoint(&self) -> &str { - "https://github.com/login/device/code" + fn device_code_endpoint(&self) -> Option<&str> { + Some("https://github.com/login/device/code") } fn auth_url_endpoint(&self) -> &str { diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs index a38024d7..aacb4837 100644 --- a/v-api/src/endpoints/login/oauth/google.rs +++ b/v-api/src/endpoints/login/oauth/google.rs @@ -162,8 +162,8 @@ impl OAuthProvider for GoogleOAuthProvider { ] } - fn device_code_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/device/code" + fn device_code_endpoint(&self) -> Option<&str> { + Some("https://oauth2.googleapis.com/device/code") } fn auth_url_endpoint(&self) -> &str { diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index f89d822d..08d2edde 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -27,6 +27,7 @@ pub mod code; pub mod device_token; pub mod github; pub mod google; +pub mod zendesk; #[derive(Debug, Error)] pub enum OAuthProviderError { @@ -76,7 +77,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { // TODO: How can user info be change to something statically checked instead of a runtime check fn user_info_endpoints(&self) -> Vec<&str>; - fn device_code_endpoint(&self) -> &str; + fn device_code_endpoint(&self) -> Option<&str>; fn auth_url_endpoint(&self) -> &str; fn token_exchange_content_type(&self) -> &str; fn token_exchange_endpoint(&self) -> &str; @@ -88,7 +89,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { provider: self.name(), client_id: self.client_id(client_type).to_string(), auth_url_endpoint: self.auth_url_endpoint().to_string(), - device_code_endpoint: self.device_code_endpoint().to_string(), + device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), token_endpoint: format!("{}/login/oauth/{}/device/exchange", public_url, self.name(),), scopes: self .scopes() @@ -171,7 +172,7 @@ pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, auth_url_endpoint: String, - device_code_endpoint: String, + device_code_endpoint: Option, token_endpoint: String, scopes: Vec, } @@ -182,6 +183,7 @@ pub enum OAuthProviderName { #[serde(rename = "github")] GitHub, Google, + Zendesk, } impl Display for OAuthProviderName { @@ -189,6 +191,7 @@ impl Display for OAuthProviderName { match self { OAuthProviderName::GitHub => write!(f, "github"), OAuthProviderName::Google => write!(f, "google"), + OAuthProviderName::Zendesk => write!(f, "zendesk"), } } } diff --git a/v-api/src/endpoints/login/oauth/zendesk.rs b/v-api/src/endpoints/login/oauth/zendesk.rs new file mode 100644 index 00000000..6b3c56e1 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/zendesk.rs @@ -0,0 +1,173 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use secrecy::SecretString; +use serde::Deserialize; +use std::fmt; + +use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; + +use super::{ + ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, + OAuthPublicCredentials, +}; + +pub struct ZendeskOAuthProvider { + device_public: OAuthPublicCredentials, + device_private: Option, + web_public: OAuthPublicCredentials, + web_private: Option, + additional_scopes: Vec, + client: reqwest::Client, + user_info_endpoint: String, + auth_url_endpoint: String, + token_exchange_endpoint: String, +} + +impl fmt::Debug for ZendeskOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ZendeskOAuthProvider").finish() + } +} + +impl ZendeskOAuthProvider { + pub fn new( + subdomain: String, + device_client_id: String, + device_client_secret: SecretString, + web_client_id: String, + web_client_secret: SecretString, + additional_scopes: Option>, + ) -> Self { + let base_url = format!("https://{}.zendesk.com", subdomain); + + Self { + device_public: OAuthPublicCredentials { + client_id: device_client_id, + }, + device_private: Some(OAuthPrivateCredentials { + client_secret: device_client_secret, + }), + web_public: OAuthPublicCredentials { + client_id: web_client_id, + }, + web_private: Some(OAuthPrivateCredentials { + client_secret: web_client_secret, + }), + additional_scopes: additional_scopes.unwrap_or_default(), + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + token_exchange_endpoint: format!("{}/oauth/tokens", base_url), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct ZendeskUserResponse { + user: ZendeskUser, +} + +#[derive(Debug, Deserialize)] +struct ZendeskUser { + id: u64, + name: String, + email: String, + verified: bool, +} + +impl ExtractUserInfo for ZendeskOAuthProvider { + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let response: ZendeskUserResponse = serde_json::from_slice(&data[0])?; + let user = response.user; + + let verified_emails = if user.verified { + vec![user.email] + } else { + vec![] + }; + + Ok(UserInfo { + external_id: ExternalUserId::Zendesk(user.id.to_string()), + verified_emails, + display_name: Some(user.name), + }) + } +} + +impl OAuthProvider for ZendeskOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Zendesk + } + + fn scopes(&self) -> Vec<&str> { + let mut default = vec!["users:read"]; + default.extend(self.additional_scopes.iter().map(|s| s.as_str())); + default + } + + fn initialize_headers(&self, _request: &mut Request) {} + + fn client(&self) -> &reqwest::Client { + &self.client + } + + fn client_id(&self, client_type: &ClientType) -> &str { + match client_type { + ClientType::Device => &self.device_public.client_id, + ClientType::Web => &self.web_public.client_id, + } + } + + fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { + match client_type { + ClientType::Device => self + .device_private + .as_ref() + .map(|private| &private.client_secret), + ClientType::Web => self + .web_private + .as_ref() + .map(|private| &private.client_secret), + } + } + + fn user_info_endpoints(&self) -> Vec<&str> { + vec![&self.user_info_endpoint] + } + + fn device_code_endpoint(&self) -> Option<&str> { + None + } + + fn auth_url_endpoint(&self) -> &str { + &self.auth_url_endpoint + } + + fn token_exchange_content_type(&self) -> &str { + "application/x-www-form-urlencoded" + } + + fn token_exchange_endpoint(&self) -> &str { + &self.token_exchange_endpoint + } + + fn token_revocation_endpoint(&self) -> Option<&str> { + None + } + + fn supports_pkce(&self) -> bool { + false + } +} From c1988ab054028ee8cca0a3197588b1da5f0f08a8 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 29 Apr 2026 10:10:10 -0500 Subject: [PATCH 02/51] Add experimental crate for cli helpers --- Cargo.lock | 35 ++++++- Cargo.toml | 2 + v-cli-sdk/Cargo.toml | 15 +++ v-cli-sdk/src/cmd/auth/login.rs | 162 ++++++++++++++++++++++++++++++++ v-cli-sdk/src/cmd/auth/mod.rs | 34 +++++++ v-cli-sdk/src/cmd/auth/oauth.rs | 91 ++++++++++++++++++ v-cli-sdk/src/cmd/config/mod.rs | 134 ++++++++++++++++++++++++++ v-cli-sdk/src/cmd/mod.rs | 2 + v-cli-sdk/src/err.rs | 64 +++++++++++++ v-cli-sdk/src/lib.rs | 52 ++++++++++ 10 files changed, 589 insertions(+), 2 deletions(-) create mode 100644 v-cli-sdk/Cargo.toml create mode 100644 v-cli-sdk/src/cmd/auth/login.rs create mode 100644 v-cli-sdk/src/cmd/auth/mod.rs create mode 100644 v-cli-sdk/src/cmd/auth/oauth.rs create mode 100644 v-cli-sdk/src/cmd/config/mod.rs create mode 100644 v-cli-sdk/src/cmd/mod.rs create mode 100644 v-cli-sdk/src/err.rs create mode 100644 v-cli-sdk/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index c64419dd..ee5d1518 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2272,6 +2272,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "progenitor-client" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8a874cf25a33cac7a01b9c1de87bcfbc8aea93f3156d09dcc3bee516a78926" +dependencies = [ + "bytes", + "futures-core", + "percent-encoding", + "reqwest 0.13.2", + "serde", + "serde_json", + "serde_urlencoded", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2558,6 +2573,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "serde_json", + "serde_urlencoded", "sync_wrapper", "tokio", "tokio-rustls 0.26.4", @@ -2732,7 +2748,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3858,6 +3874,21 @@ dependencies = [ "v-model", ] +[[package]] +name = "v-cli-sdk" +version = "0.2.0" +dependencies = [ + "anyhow", + "chrono", + "clap", + "oauth2", + "oauth2-reqwest", + "progenitor-client", + "reqwest 0.13.2", + "serde", + "tokio", +] + [[package]] name = "v-model" version = "0.2.0" @@ -4118,7 +4149,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 7d58a2ae..da03d8db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "v-api-installer", "v-api-param", "v-api-permission-derive", + "v-cli-sdk", "v-model", "xtask" ] @@ -37,6 +38,7 @@ oauth2-reqwest = "0.1.0-alpha.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } percent-encoding = "2.3.2" proc-macro2 = "1" +progenitor-client = "0.14.0" quote = "1" rand = "0.10.1" rand_core = "0.10.1" diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml new file mode 100644 index 00000000..52f46795 --- /dev/null +++ b/v-cli-sdk/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "v-cli-sdk" +version = "0.2.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +chrono = { workspace = true } +clap = { workspace = true } +oauth2 = { workspace = true } +oauth2-reqwest = { workspace = true } +progenitor-client = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +tokio = { workspace = true } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs new file mode 100644 index 00000000..089e2808 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -0,0 +1,162 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand, ValueEnum}; +use oauth2::TokenResponse; +use std::{error::Error as StdError, future::Future, io::Write, pin::Pin}; + +use crate::{CliContext, cmd::{auth::oauth::{self, CliOAuthProviderInfo}, config::CliConfig}}; + +// Authenticates and generates an access token for interacting with the api +#[derive(Parser, Debug, Clone)] +#[clap(name = "login")] +pub struct Login { + #[command(subcommand)] + method: LoginMethod, + #[arg(short = 'm', default_value = "id")] + mode: AuthenticationMode, +} + +impl Login { + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + { + let access_token = self.method.run(ctx, &self.mode).await?; + + ctx.config_mut().set_token(access_token); + ctx.config_mut().save()?; + + Ok(()) + } +} + +#[derive(Subcommand, Debug, Clone)] +pub enum LoginMethod { + #[command(name = "oauth")] + /// Login via OAuth + OAuth { + #[command(subcommand)] + provider: LoginProvider, + }, + /// Login via Magic Link + #[command(name = "mlink")] + MagicLink { + /// Email recipient to login via + email: String, + /// Optional access scopes to apply to this session + scope: Option, + }, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum LoginProvider { + #[command(name = "google")] + /// Login via Google + Google, +} + +#[derive(ValueEnum, Debug, Clone, PartialEq)] +pub enum AuthenticationMode { + /// Retrieve and store an identity token. Identity mode is the default and should be used to + /// when you do not require extended (multi-day) access + #[value(name = "id")] + Identity, + /// Retrieve and store an api token. Token mode should be used when you want to authenticate + /// a machine for continued access. This requires the permission to create api tokens + #[value(name = "token")] + Token, +} + +impl LoginMethod { + pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result + where + T: CliContext, + { + match self { + Self::OAuth { provider } => { + self.run_oauth_provider(provider, mode, ctx.oauth_adapter()) + .await + } + Self::MagicLink { email, scope } => { + self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await + } + } + } + + async fn run_oauth_provider( + &self, + provider: &LoginProvider, + mode: &AuthenticationMode, + adapter: T + ) -> Result where T: CliOAuthAdapter { + let provider = adapter.provider(provider).await?; + let oauth_client = oauth::DeviceOAuth::new(provider)?; + let details = oauth_client.get_device_authorization().await?; + + println!( + "To complete login visit: {} and enter {}", + details.verification_uri().as_str(), + details.user_code().secret() + ); + + let token_response = oauth_client.login(&details).await; + + let identity_token = match token_response { + Ok(token) => Ok(token.access_token().to_owned()), + Err(err) => Err(anyhow::anyhow!("Authentication failed: {}", err)), + }?; + + if mode == &AuthenticationMode::Token { + let token = adapter.get_long_lived_token(identity_token.secret()).await?; + Ok(token.access_token().to_string()) + } else { + Ok(identity_token.secret().to_string()) + } + } + + async fn run_magic_link( + &self, + email: &str, + scope: Option<&str>, + adapter: T, + ) -> Result + where + T: CliMagicLinkAdapter, + { + let attempt = adapter.create_attempt(email, scope).await?; + + let mut auth_secret = String::new(); + print!("Enter the login token sent to the recipient: "); + std::io::stdout().flush()?; + std::io::stdin().read_line(&mut auth_secret)?; + + let token = adapter.exchange(attempt, email, &auth_secret).await?; + + Ok(token.access_token().to_string()) + } +} + +pub trait CliOAuthAdapter { + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + fn provider(&self, provider: &LoginProvider) -> Pin> + Send>>; + fn get_long_lived_token(&self, access_token: &str) -> Pin> + Send>>; +} + +pub trait CliMagicLinkAdapter { + type Attempt; + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + fn create_attempt(&self, email: &str, scope: Option<&str>) -> Pin> + Send>>; + fn exchange(&self, attempt: Self::Attempt, email: &str, token: &str) -> Pin> + Send>>; +} + +pub trait CliAdapterToken { + fn access_token(&self) -> &str; +} diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs new file mode 100644 index 00000000..e50ce8f6 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -0,0 +1,34 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; + +use crate::CliContext; + +// mod link; +pub mod login; +pub mod oauth; + +// Authenticate against the Meetings API +#[derive(Parser, Debug)] +#[clap(name = "auth")] +pub struct Auth { + #[command(subcommand)] + auth: AuthCommands, +} + +#[derive(Subcommand, Debug, Clone)] +enum AuthCommands { + /// Login via an authentication provider + Login(login::Login), +} + +impl Auth { + pub async fn run(&self, ctx: &mut T) -> Result<()> where T: CliContext { + match &self.auth { + AuthCommands::Login(login) => login.run(ctx).await, + } + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth.rs b/v-cli-sdk/src/cmd/auth/oauth.rs new file mode 100644 index 00000000..36301a6f --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth.rs @@ -0,0 +1,91 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use oauth2::basic::{BasicClient, BasicTokenType}; +use oauth2::{ + AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, + EndpointSet, Scope, StandardTokenResponse, TokenUrl, +}; +use oauth2::StandardDeviceAuthorizationResponse; + +type DeviceClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct DeviceOAuth { + client: DeviceClient, + http: oauth2_reqwest::ReqwestClient, + scopes: Vec, +} + +impl DeviceOAuth { + pub fn new(provider: T) -> Result where T: CliOAuthProviderInfo { + if let Some(device_endpoint) = provider.device_code_endpoint() { + let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_device_authorization_url(device_auth_url); + + Ok(Self { + client, + http: oauth2_reqwest::ReqwestClient::from( + reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(), + ), + scopes: provider.scopes().into_iter().map(|s| s.to_string()).collect(), + }) + } else { + anyhow::bail!("Device authorization is not supported by this provider") + } + } + + pub async fn login( + &self, + details: &StandardDeviceAuthorizationResponse, + ) -> Result> { + let token = self + .client + .exchange_device_access_token(details) + .set_max_backoff_interval(details.interval()) + .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) + .await; + + Ok(token?) + } + + pub async fn get_device_authorization(&self) -> Result { + let mut req = self.client.exchange_device_code(); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + let res = req.request_async(&self.http).await; + + Ok(res?) + } +} + +pub trait CliOAuthProviderInfo { + fn device_code_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> &str; + fn token_endpoint(&self) -> &str; + fn client_id(&self) -> &str; + fn scopes(&self) -> &[String]; +} diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs new file mode 100644 index 00000000..d643965b --- /dev/null +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -0,0 +1,134 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; + +use crate::{CliContext, FormatStyle}; + +pub trait CliConfig { + fn host(&self) -> Option<&str>; + fn set_host(&mut self, host: String); + fn token(&self) -> Option<&str>; + fn set_token(&mut self, token: String); + fn default_format(&self) -> Option<&FormatStyle>; + fn set_default_format(&mut self, format: FormatStyle); + fn mlink_redirect(&self) -> Option<&str>; + fn set_mlink_redirect(&mut self, redirect: String); + fn mlink_secret(&self) -> Option<&str>; + fn set_mlink_secret(&mut self, secret: String); + fn save(&self) -> Result<(), std::io::Error>; +} + +#[derive(Debug, Parser)] +#[clap(name = "config")] +pub struct ConfigCmd { + #[clap(subcommand)] + setting: SettingCmd, +} + +#[derive(Debug, Subcommand)] +pub enum SettingCmd { + /// Gets a setting + #[clap(subcommand, name = "get")] + Get(GetCmd), + /// Sets a setting + #[clap(subcommand, name = "set")] + Set(SetCmd), +} + +#[derive(Debug, Subcommand)] +pub enum GetCmd { + /// Get the default formatter to use when printing results + #[clap(name = "format")] + Format, + /// Get the configured API host in use + #[clap(name = "host")] + Host, + /// Get the configured access token + #[clap(name = "token")] + Token, + /// Get the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri, + /// Get the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret, +} + +#[derive(Debug, Subcommand)] +pub enum SetCmd { + /// Set the default formatter to use when printing results + #[clap(name = "format")] + Format { format: FormatStyle }, + /// Set the configured API host to use + #[clap(name = "host")] + Host { host: String }, + /// Set the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri { redirect: String }, + /// Set the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret { secret: String }, +} + +impl ConfigCmd { + pub async fn run(&self, ctx: &mut T) -> Result<()> where T: CliContext { + match &self.setting { + SettingCmd::Get(get) => get.run(ctx.config()).await?, + SettingCmd::Set(set) => set.run(ctx.config_mut()).await?, + } + + Ok(()) + } +} + +impl GetCmd { + pub async fn run(&self, config: &T) -> Result<()> where T: CliConfig { + match &self { + GetCmd::Format => { + println!("{}", config.default_format().map(|f| *f).unwrap_or(FormatStyle::Json)); + } + GetCmd::Host => { + println!("{}", config.host().unwrap_or("None")); + } + GetCmd::Token => { + println!("{}", config.token().unwrap_or("None")); + } + GetCmd::MagicLinkRedirectUri => { + println!("{}", config.mlink_redirect().unwrap_or("None")); + } + GetCmd::MagicLinkSecret => { + println!("{}", config.mlink_secret().unwrap_or("None")); + } + } + + Ok(()) + } +} + +impl SetCmd { + pub async fn run(&self, config: &mut T) -> Result<()> where T: CliConfig { + match &self { + SetCmd::Format { format } => { + config.set_default_format(format.clone()); + config.save()?; + } + SetCmd::Host { host } => { + config.set_host(host.to_string()); + config.save()?; + } + SetCmd::MagicLinkRedirectUri { redirect } => { + config.set_mlink_redirect(redirect.to_string()); + config.save()?; + } + SetCmd::MagicLinkSecret { secret } => { + config.set_mlink_secret(secret.to_string()); + config.save()?; + } + } + + Ok(()) + } +} diff --git a/v-cli-sdk/src/cmd/mod.rs b/v-cli-sdk/src/cmd/mod.rs new file mode 100644 index 00000000..c73c20da --- /dev/null +++ b/v-cli-sdk/src/cmd/mod.rs @@ -0,0 +1,2 @@ +pub mod auth; +pub mod config; diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs new file mode 100644 index 00000000..2c7670e4 --- /dev/null +++ b/v-cli-sdk/src/err.rs @@ -0,0 +1,64 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::{Error, anyhow}; +use progenitor_client::Error as ProgenitorClientError; + +use crate::{ApiErrorMessage, CliContext, VerbosityLevel}; + +pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error where T: CliContext, E: ApiErrorMessage { + let mut err = anyhow!("API Request failed"); + + match client_err { + ProgenitorClientError::CommunicationError(inner) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context("Communication Error").context(inner); + } + } + ProgenitorClientError::ErrorResponse(response) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Status: {}", response.status())); + err = err.context(format!("Headers {:?}", response.headers())); + } + + let response_message = response.into_inner(); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Request {}", response_message.request_id().as_deref().unwrap_or(""))); + } + + err = err.context(format!( + "Code: {}", + response_message.error_code().as_deref().unwrap_or("") + )); + err = err.context(response_message.message().as_deref().unwrap_or("").to_string()); + } + ProgenitorClientError::InvalidRequest(message) => { + err = err.context("Invalid request").context(message); + } + ProgenitorClientError::InvalidResponsePayload(_, inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::UnexpectedResponse(response) => { + err = err + .context("Unexpected response") + .context(format!("Status: {}", response.status())); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Headers {:?}", response.headers())); + } + } + ProgenitorClientError::ResponseBodyError(inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::InvalidUpgrade(inner) => { + err = err.context("Invalid upgrade").context(inner) + } + ProgenitorClientError::Custom(inner) => { + err = err.context("Inner progenitor error").context(inner) + } + } + + err +} diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs new file mode 100644 index 00000000..b5375292 --- /dev/null +++ b/v-cli-sdk/src/lib.rs @@ -0,0 +1,52 @@ +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +use crate::cmd::{auth::login::{CliMagicLinkAdapter, CliOAuthAdapter}, config::CliConfig}; + +pub mod cmd; +pub mod err; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum VerbosityLevel { + None, + All, +} + +#[derive(Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize)] +pub enum FormatStyle { + #[value(name = "json")] + Json, + #[value(name = "tab")] + Tab, +} + +impl Display for FormatStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Json => write!(f, "json"), + Self::Tab => write!(f, "tab"), + } + } +} + +pub trait CliContext { + type Attempt; + type Token; + type Error; + + fn config(&self) -> &impl CliConfig; + fn config_mut(&mut self) -> &mut impl CliConfig; + fn client(&self) -> Option<&C>; + fn printer(&self) -> Option<&P>; + fn verbosity(&self) -> VerbosityLevel; + + fn oauth_adapter(&self) -> impl CliOAuthAdapter; + fn mlink_adapter(&self) -> impl CliMagicLinkAdapter; +} + +pub trait ApiErrorMessage { + fn message(&self) -> Option<&str>; + fn error_code(&self) -> Option<&str>; + fn request_id(&self) -> Option<&str>; +} From 7f99622336103a990f8e88c1cc848fcf566e2463 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 29 Apr 2026 15:37:31 -0500 Subject: [PATCH 03/51] OAuth code proxy work --- Cargo.lock | 4 + Cargo.toml | 1 + v-api/src/context/mod.rs | 3 +- v-api/src/endpoints/login/mod.rs | 2 + v-api/src/endpoints/login/oauth/code.rs | 19 +- .../src/endpoints/login/oauth/device_token.rs | 5 +- v-api/src/endpoints/login/oauth/github.rs | 17 ++ v-api/src/endpoints/login/oauth/google.rs | 17 ++ v-api/src/endpoints/login/oauth/mod.rs | 35 ++- v-api/src/endpoints/login/oauth/zendesk.rs | 29 +++ v-cli-sdk/Cargo.toml | 6 +- v-cli-sdk/src/cmd/auth/login.rs | 144 ++++++++--- v-cli-sdk/src/cmd/auth/mod.rs | 30 ++- v-cli-sdk/src/cmd/auth/oauth.rs | 241 +++++++++++++++++- v-cli-sdk/src/cmd/auth/proxy.rs | 129 ++++++++++ v-cli-sdk/src/cmd/config/mod.rs | 23 +- v-cli-sdk/src/err.rs | 21 +- v-cli-sdk/src/lib.rs | 13 +- 18 files changed, 636 insertions(+), 103 deletions(-) create mode 100644 v-cli-sdk/src/cmd/auth/proxy.rs diff --git a/Cargo.lock b/Cargo.lock index ee5d1518..5a3f6cf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3881,6 +3881,10 @@ dependencies = [ "anyhow", "chrono", "clap", + "http", + "http-body-util", + "hyper", + "hyper-util", "oauth2", "oauth2-reqwest", "progenitor-client", diff --git a/Cargo.toml b/Cargo.toml index da03d8db..e2353bad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ hex = "0.4.3" http = "1" http-body-util = "0.1.3" hyper = "1.9.0" +hyper-util = "0.1" jsonwebtoken = { version = "10.2", features = ["rust_crypto"] } mockall = "0.14.0" newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] } diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index b73abdb3..f0dbd40c 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1259,7 +1259,7 @@ pub(crate) mod test_mocks { pub async fn mock_context(storage: Arc) -> VContext { let MockKey { signer, verifier } = mock_key("test"); let mut ctx = VContextBuilder::::new() - .with_public_url("".to_string()) + .with_public_url("https://test_public_url".to_string()) .with_storage(storage) .with_jwt_expiration(JwtConfig::default().default_expiration) .with_keys(vec![signer, verifier]) @@ -1278,6 +1278,7 @@ pub(crate) mod test_mocks { OAuthProviderName::Google, Box::new(move || { Box::new(GoogleOAuthProvider::new( + "https://test_public_url".to_string(), "google_device_client_id".to_string(), "google_device_client_secret".to_string().into(), "google_web_client_id".to_string(), diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index d6c30b30..05a7207e 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -201,6 +201,8 @@ pub enum UserInfoError { Deserialize(#[from] serde_json::Error), #[error("Failed to create user info request {0}")] Http(#[from] http::Error), + #[error("User account is locked")] + Locked, #[error("User information is missing")] MissingUserInfoData(String), } diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs index 39d3a4a8..969547f9 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/code.rs @@ -30,7 +30,7 @@ use v_model::{ LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, }; -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig}; +use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider}; use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, @@ -239,11 +239,7 @@ fn oauth_redirect_response( // TODO: This behavior should be changed so that clients are precomputed. We do not need to be // constructing a new client on every request. That said, we need to ensure the client does not // maintain state between requests - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; // Create an attempt cookie header for storing the login attempt. This also acts as our csrf // check @@ -715,11 +711,7 @@ async fn fetch_user_info( attempt: &LoginAttempt, ) -> Result { // Exchange the stored authorization code with the remote provider for a remote access token - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; let mut request = client.exchange_code(AuthorizationCode::new( attempt @@ -889,8 +881,7 @@ mod tests { #[tokio::test] async fn test_remote_provider_redirect_url() { let storage = MockStorage::new(); - let mut ctx = mock_context(Arc::new(storage)).await; - ctx.with_public_url("https://api.oxeng.dev"); + let ctx = mock_context(Arc::new(storage)).await; let (challenge, _) = PkceCodeChallenge::new_random_sha256(); let attempt = LoginAttempt { @@ -927,7 +918,7 @@ mod tests { .unwrap(); let headers = response.headers(); - let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); + let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); assert_eq!( expected_location, diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs index 5fb2eb04..22488a61 100644 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ b/v-api/src/endpoints/login/oauth/device_token.rs @@ -40,10 +40,7 @@ where .await .map_err(ApiError::OAuth)?; - Ok(HttpResponseOk(provider.provider_info( - rqctx.v_ctx().public_url(), - &ClientType::Device, - ))) + Ok(HttpResponseOk(provider.provider_info(&ClientType::Device))) } #[derive(Debug, Deserialize, JsonSchema, Serialize)] diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/github.rs index 614f7d36..3f3483fe 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/github.rs @@ -26,6 +26,9 @@ pub struct GitHubOAuthProvider { additional_scopes: Vec, default_headers: HeaderMap, client: reqwest::Client, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, } impl fmt::Debug for GitHubOAuthProvider { @@ -36,6 +39,7 @@ impl fmt::Debug for GitHubOAuthProvider { impl GitHubOAuthProvider { pub fn new( + public_url: String, device_client_id: String, device_client_secret: SecretString, web_client_id: String, @@ -64,6 +68,9 @@ impl GitHubOAuthProvider { .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), + token_endpoint: Some(format!("{}/login/oauth/github/device/exchange", public_url)), + redirect_endpoint: Some(format!("{}/login/oauth/github/code/callback", public_url,)), + redirect_proxy_endpoint: None, } } @@ -177,4 +184,14 @@ impl OAuthProvider for GitHubOAuthProvider { fn supports_pkce(&self) -> bool { true } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } } diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs index aacb4837..e477ae44 100644 --- a/v-api/src/endpoints/login/oauth/google.rs +++ b/v-api/src/endpoints/login/oauth/google.rs @@ -22,6 +22,9 @@ pub struct GoogleOAuthProvider { web_private: Option, additional_scopes: Vec, client: reqwest::Client, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, } impl fmt::Debug for GoogleOAuthProvider { @@ -32,6 +35,7 @@ impl fmt::Debug for GoogleOAuthProvider { impl GoogleOAuthProvider { pub fn new( + public_url: String, device_client_id: String, device_client_secret: SecretString, web_client_id: String, @@ -56,6 +60,9 @@ impl GoogleOAuthProvider { .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), + token_endpoint: Some(format!("{}/login/oauth/google/device/exchange", public_url)), + redirect_endpoint: Some(format!("{}/login/oauth/google/code/callback", public_url,)), + redirect_proxy_endpoint: None, } } @@ -185,4 +192,14 @@ impl OAuthProvider for GoogleOAuthProvider { fn supports_pkce(&self) -> bool { true } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } } diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 08d2edde..b986c5c4 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -33,6 +33,10 @@ pub mod zendesk; pub enum OAuthProviderError { #[error("Unable to instantiate invalid provider")] FailToCreateInvalidProvider, + #[error("Missing redirect URI")] + MissingRedirectUri, + #[error("Failed to parse URL")] + UrlParseError(#[from] ParseError), } #[derive(Debug)] @@ -41,11 +45,6 @@ pub enum ClientType { Web, } -#[derive(Debug)] -pub struct WebClientConfig { - prefix: String, -} - pub type WebClient = BasicClient< // HasAuthUrl EndpointSet, @@ -84,13 +83,19 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn token_revocation_endpoint(&self) -> Option<&str>; fn supports_pkce(&self) -> bool; - fn provider_info(&self, public_url: &str, client_type: &ClientType) -> OAuthProviderInfo { + fn token_endpoint(&self) -> Option<&str>; + fn redirect_endpoint(&self) -> Option<&str>; + fn redirect_proxy_endpoint(&self) -> Option<&str>; + + fn provider_info(&self, client_type: &ClientType) -> OAuthProviderInfo { OAuthProviderInfo { provider: self.name(), client_id: self.client_id(client_type).to_string(), auth_url_endpoint: self.auth_url_endpoint().to_string(), device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), - token_endpoint: format!("{}/login/oauth/{}/device/exchange", public_url, self.name(),), + token_endpoint: self.token_endpoint().map(|s| s.to_string()), + redirect_endpoint: self.redirect_endpoint().map(|s| s.to_string()), + redirect_proxy_endpoint: self.redirect_proxy_endpoint().map(|s| s.to_string()), scopes: self .scopes() .into_iter() @@ -99,7 +104,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { } } - fn as_web_client(&self, config: &WebClientConfig) -> Result { + fn as_web_client(&self) -> Result { let mut client = BasicClient::new(ClientId::new(self.client_id(&ClientType::Web).to_string())) .set_auth_uri(AuthUrl::new(self.auth_url_endpoint().to_string())?) @@ -109,11 +114,11 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { .map(|s| RevocationUrl::new(s.to_string())) .transpose()?, ) - .set_redirect_uri(RedirectUrl::new(format!( - "{}/login/oauth/{}/code/callback", - &config.prefix, - self.name() - ))?); + .set_redirect_uri(RedirectUrl::new( + self.redirect_endpoint() + .ok_or(OAuthProviderError::MissingRedirectUri)? + .to_string(), + )?); if let Some(secret) = self.client_secret(&ClientType::Web) { client = client.set_client_secret(ClientSecret::new(secret.expose_secret().to_string())) @@ -173,7 +178,9 @@ pub struct OAuthProviderInfo { client_id: String, auth_url_endpoint: String, device_code_endpoint: Option, - token_endpoint: String, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, scopes: Vec, } diff --git a/v-api/src/endpoints/login/oauth/zendesk.rs b/v-api/src/endpoints/login/oauth/zendesk.rs index 6b3c56e1..b2421ce2 100644 --- a/v-api/src/endpoints/login/oauth/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/zendesk.rs @@ -25,6 +25,9 @@ pub struct ZendeskOAuthProvider { user_info_endpoint: String, auth_url_endpoint: String, token_exchange_endpoint: String, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, } impl fmt::Debug for ZendeskOAuthProvider { @@ -35,12 +38,14 @@ impl fmt::Debug for ZendeskOAuthProvider { impl ZendeskOAuthProvider { pub fn new( + public_url: String, subdomain: String, device_client_id: String, device_client_secret: SecretString, web_client_id: String, web_client_secret: SecretString, additional_scopes: Option>, + redirect_proxy_port: u16, ) -> Self { let base_url = format!("https://{}.zendesk.com", subdomain); @@ -65,6 +70,15 @@ impl ZendeskOAuthProvider { user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), token_exchange_endpoint: format!("{}/oauth/tokens", base_url), + token_endpoint: Some(format!( + "{}/login/oauth/zendesk/device/exchange", + public_url + )), + redirect_endpoint: Some(format!("{}/login/oauth/zendesk/code/callback", public_url,)), + redirect_proxy_endpoint: Some(format!( + "http://localhost:{}/login/oauth/zendesk/code/callback", + redirect_proxy_port + )), } } @@ -85,6 +99,7 @@ struct ZendeskUser { name: String, email: String, verified: bool, + suspended: bool, } impl ExtractUserInfo for ZendeskOAuthProvider { @@ -92,6 +107,10 @@ impl ExtractUserInfo for ZendeskOAuthProvider { let response: ZendeskUserResponse = serde_json::from_slice(&data[0])?; let user = response.user; + if user.suspended { + return Err(UserInfoError::Locked); + } + let verified_emails = if user.verified { vec![user.email] } else { @@ -170,4 +189,14 @@ impl OAuthProvider for ZendeskOAuthProvider { fn supports_pkce(&self) -> bool { false } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } } diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml index 52f46795..c3c630ec 100644 --- a/v-cli-sdk/Cargo.toml +++ b/v-cli-sdk/Cargo.toml @@ -7,9 +7,13 @@ edition = "2021" anyhow = { workspace = true } chrono = { workspace = true } clap = { workspace = true } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["server", "http1"] } +hyper-util = { workspace = true, features = ["tokio"] } oauth2 = { workspace = true } oauth2-reqwest = { workspace = true } progenitor-client = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } -tokio = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 089e2808..6d09a937 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -5,24 +5,44 @@ use anyhow::Result; use clap::{Parser, Subcommand, ValueEnum}; use oauth2::TokenResponse; -use std::{error::Error as StdError, future::Future, io::Write, pin::Pin}; +use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; -use crate::{CliContext, cmd::{auth::oauth::{self, CliOAuthProviderInfo}, config::CliConfig}}; +use crate::{ + cmd::{ + auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, + config::CliConfig, + }, + CliContext, +}; + +pub trait CliAdapterToken { + fn access_token(&self) -> &str; +} + +pub trait CliConsumerLoginProvider: Into + Subcommand + Debug + Clone {} +impl CliConsumerLoginProvider for T where T: Into + Subcommand + Debug + Clone {} // Authenticates and generates an access token for interacting with the api #[derive(Parser, Debug, Clone)] #[clap(name = "login")] -pub struct Login { +pub struct Login

+where + P: CliConsumerLoginProvider, +{ #[command(subcommand)] - method: LoginMethod, + method: LoginMethod

, #[arg(short = 'm', default_value = "id")] mode: AuthenticationMode, } -impl Login { - pub async fn run(&self, ctx: &mut T) -> Result<()> +impl

Login

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> where - T: CliContext, + T: CliContext, + >::Error: StdError + Send + Sync + 'static, { let access_token = self.method.run(ctx, &self.mode).await?; @@ -34,12 +54,15 @@ impl Login { } #[derive(Subcommand, Debug, Clone)] -pub enum LoginMethod { +pub enum LoginMethod

+where + P: Subcommand + Debug + Clone, +{ #[command(name = "oauth")] /// Login via OAuth OAuth { #[command(subcommand)] - provider: LoginProvider, + provider: P, }, /// Login via Magic Link #[command(name = "mlink")] @@ -51,11 +74,10 @@ pub enum LoginMethod { }, } -#[derive(Subcommand, Debug, Clone)] pub enum LoginProvider { - #[command(name = "google")] - /// Login via Google Google, + GitHub, + Zendesk, } #[derive(ValueEnum, Debug, Clone, PartialEq)] @@ -68,17 +90,40 @@ pub enum AuthenticationMode { /// a machine for continued access. This requires the permission to create api tokens #[value(name = "token")] Token, + /// Retrieve and store a remote token. Remote mode should be used when you want to authenticate + /// and retrieve a token for use against the underlying authentication provider + #[value(name = "remote")] + Remote, } -impl LoginMethod { - pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result +impl

LoginMethod

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result where - T: CliContext, + T: CliContext, + >::Error: StdError + Send + Sync + 'static, { match self { Self::OAuth { provider } => { - self.run_oauth_provider(provider, mode, ctx.oauth_adapter()) - .await + let adapter = ctx.oauth_adapter(); + let provider = provider.clone().into(); + let provider = adapter.provider(&provider).await?; + + // We now need to inspect the provider to determine the correct flow to use. If + // possible we use a limited input device flow, but not all providers support it. + // To handle those cases we need to use a proxy path that emulates an authorization + // code flow. + if provider.device_code_endpoint().is_some() { + self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await + } else if provider.code_redirect_proxy_endpoint().is_some() { + self.run_oauth_code_provider(provider, mode, ctx.oauth_adapter()) + .await + } else { + anyhow::bail!("OAuth provider does not support any CLI authentication methods") + } } Self::MagicLink { email, scope } => { self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) @@ -87,13 +132,16 @@ impl LoginMethod { } } - async fn run_oauth_provider( + async fn run_oauth_device_provider( &self, - provider: &LoginProvider, + provider: V, mode: &AuthenticationMode, - adapter: T - ) -> Result where T: CliOAuthAdapter { - let provider = adapter.provider(provider).await?; + adapter: T, + ) -> Result + where + T: CliOAuthAdapter, + V: CliOAuthProviderInfo, + { let oauth_client = oauth::DeviceOAuth::new(provider)?; let details = oauth_client.get_device_authorization().await?; @@ -111,13 +159,38 @@ impl LoginMethod { }?; if mode == &AuthenticationMode::Token { - let token = adapter.get_long_lived_token(identity_token.secret()).await?; + let token = adapter + .get_long_lived_token(identity_token.secret()) + .await?; Ok(token.access_token().to_string()) } else { Ok(identity_token.secret().to_string()) } } + async fn run_oauth_code_provider( + &self, + provider: V, + mode: &AuthenticationMode, + adapter: T, + ) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + V: CliOAuthProviderInfo, + { + let oauth_client = oauth::CodeOAuth::new(provider)?; + let adapter = Arc::new(adapter); + + let identity_token = oauth_client.login(Arc::clone(&adapter)).await?; + + if mode == &AuthenticationMode::Token { + let token = adapter.get_long_lived_token(&identity_token).await?; + Ok(token.access_token().to_string()) + } else { + Ok(identity_token) + } + } + async fn run_magic_link( &self, email: &str, @@ -140,23 +213,20 @@ impl LoginMethod { } } -pub trait CliOAuthAdapter { - type Token: CliAdapterToken; - type Error: StdError + Send + Sync + 'static; - - fn provider(&self, provider: &LoginProvider) -> Pin> + Send>>; - fn get_long_lived_token(&self, access_token: &str) -> Pin> + Send>>; -} - pub trait CliMagicLinkAdapter { type Attempt; type Token: CliAdapterToken; type Error: StdError + Send + Sync + 'static; - fn create_attempt(&self, email: &str, scope: Option<&str>) -> Pin> + Send>>; - fn exchange(&self, attempt: Self::Attempt, email: &str, token: &str) -> Pin> + Send>>; -} - -pub trait CliAdapterToken { - fn access_token(&self) -> &str; + fn create_attempt( + &self, + email: &str, + scope: Option<&str>, + ) -> Pin> + Send>>; + fn exchange( + &self, + attempt: Self::Attempt, + email: &str, + token: &str, + ) -> Pin> + Send>>; } diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs index e50ce8f6..5de8042e 100644 --- a/v-cli-sdk/src/cmd/auth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -4,29 +4,43 @@ use anyhow::Result; use clap::{Parser, Subcommand}; +use std::error::Error as StdError; -use crate::CliContext; +use crate::{cmd::auth::login::CliConsumerLoginProvider, CliContext}; -// mod link; pub mod login; pub mod oauth; +pub mod proxy; // Authenticate against the Meetings API #[derive(Parser, Debug)] #[clap(name = "auth")] -pub struct Auth { +pub struct Auth

+where + P: CliConsumerLoginProvider, +{ #[command(subcommand)] - auth: AuthCommands, + auth: AuthCommands

, } #[derive(Subcommand, Debug, Clone)] -enum AuthCommands { +enum AuthCommands

+where + P: CliConsumerLoginProvider, +{ /// Login via an authentication provider - Login(login::Login), + Login(login::Login

), } -impl Auth { - pub async fn run(&self, ctx: &mut T) -> Result<()> where T: CliContext { +impl

Auth

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + >::Error: StdError + Send + Sync + 'static, + { match &self.auth { AuthCommands::Login(login) => login.run(ctx).await, } diff --git a/v-cli-sdk/src/cmd/auth/oauth.rs b/v-cli-sdk/src/cmd/auth/oauth.rs index 36301a6f..8f5f6751 100644 --- a/v-cli-sdk/src/cmd/auth/oauth.rs +++ b/v-cli-sdk/src/cmd/auth/oauth.rs @@ -2,13 +2,233 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. +use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + use anyhow::Result; +use http::{Request, Response, StatusCode}; +use http_body_util::{BodyExt, Full}; +use hyper::body::{Bytes, Incoming}; use oauth2::basic::{BasicClient, BasicTokenType}; +use oauth2::StandardDeviceAuthorizationResponse; use oauth2::{ - AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, - EndpointSet, Scope, StandardTokenResponse, TokenUrl, + AuthType, AuthUrl, ClientId, CsrfToken, DeviceAuthorizationUrl, EmptyExtraTokenFields, + EndpointNotSet, EndpointSet, RedirectUrl, Scope, StandardTokenResponse, TokenUrl, }; -use oauth2::StandardDeviceAuthorizationResponse; +use reqwest::Url; +use tokio::sync::oneshot; + +use crate::cmd::auth::login::CliAdapterToken; + +use super::proxy::run_proxy_server; + +pub trait CliOAuthAdapter { + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + fn provider( + &self, + provider: &super::login::LoginProvider, + ) -> Pin> + Send>>; + fn exchange_authorization_code( + &self, + request: Request, + ) -> Pin>, Self::Error>> + Send>>; + fn get_long_lived_token( + &self, + access_token: &str, + ) -> Pin> + Send>>; +} + +pub trait CliOAuthProviderInfo { + fn device_code_endpoint(&self) -> Option<&str>; + fn code_redirect_proxy_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> &str; + fn token_endpoint(&self) -> &str; + fn client_id(&self) -> &str; + fn scopes(&self) -> &[String]; +} + +type CodeClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointNotSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct CodeOAuth { + client: CodeClient, + scopes: Vec, + port: u16, +} + +impl CodeOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + let redirect_url = provider + .code_redirect_proxy_endpoint() + .ok_or_else(|| anyhow::anyhow!("Provider does not support code redirect proxy flow"))?; + + let parsed_url = Url::parse(redirect_url)?; + + let port = parsed_url.port().ok_or_else(|| { + anyhow::anyhow!("Provider proxy url does not have a defined port to listen on") + })?; + + if parsed_url.scheme() != "http" { + anyhow::bail!("Provider proxy url scheme must be http"); + } + + if parsed_url + .host_str() + .map(|h| h != "localhost" && h != "127.0.0.1") + .unwrap_or(true) + { + anyhow::bail!("Provider proxy url host must be localhost"); + } + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?); + + Ok(Self { + client, + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + port, + }) + } + + /// Build the authorization URL that the user should visit in a browser. + /// Returns the full URL and the CSRF state token used for verification. + pub fn authorize_url(&self) -> (oauth2::url::Url, CsrfToken) { + let mut req = self.client.authorize_url(CsrfToken::new_random); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + req.url() + } + + /// Run the full authorization code login flow: + /// + /// 1. Generate the authorization URL and print it for the user. + /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. + /// 3. Forward the redirect request to the API server via the adapter. + /// 4. Extract the token from the server's response. + /// 5. Return a success page to the browser and shut down the proxy. + pub async fn login(&self, adapter: Arc) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + { + let (auth_url, _csrf_state) = self.authorize_url(); + + println!( + "Open the following URL in your browser to authenticate:\n\n {}\n", + auth_url + ); + + // Channel to receive the token extracted from the server response. + let (token_tx, token_rx) = oneshot::channel::>(); + let token_tx: Arc>>>> = + Arc::new(Mutex::new(Some(token_tx))); + + // Channel to shut down the proxy server once we have the token. + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let port = self.port; + + // Spawn the local proxy server in a background task. + tokio::spawn({ + let callback_token_tx = Arc::clone(&token_tx); + let error_token_tx = Arc::clone(&token_tx); + + async move { + let callback = Arc::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + // Forward the redirect request to the API server. + let response = adapter + .exchange_authorization_code(request) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // The server responds with the access token in the body. + let (_parts, body) = response.into_parts(); + let body_bytes = body + .collect() + .await + .expect("Full collection cannot fail") + .to_bytes(); + let token = String::from_utf8(body_bytes.to_vec())?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Ok(token)); + } + } + + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/html; charset=utf-8") + .body(Full::new(Bytes::from(concat!( + "", + "

Authentication successful!

", + "

You can close this tab and return to the CLI.

", + "" + ))))?) + }) + as Pin< + Box>>> + Send>, + > + }); + + if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { + eprintln!("Proxy server error: {e}"); + + // If the proxy died before we got a token, unblock the + // receiver so the caller isn't stuck forever. + if let Ok(mut guard) = error_token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); + } + } + } + } + }); + + // Wait for the proxy callback to extract the token. + let token = token_rx.await.map_err(|_| { + anyhow::anyhow!( + "Authentication callback was never received — proxy server may have exited early" + ) + })??; + + // Tell the proxy server to stop. + let _ = shutdown_tx.send(()); + + Ok(token) + } +} type DeviceClient = BasicClient< // HasAuthUrl @@ -30,7 +250,10 @@ pub struct DeviceOAuth { } impl DeviceOAuth { - pub fn new(provider: T) -> Result where T: CliOAuthProviderInfo { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { if let Some(device_endpoint) = provider.device_code_endpoint() { let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; @@ -48,7 +271,7 @@ impl DeviceOAuth { .build() .unwrap(), ), - scopes: provider.scopes().into_iter().map(|s| s.to_string()).collect(), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), }) } else { anyhow::bail!("Device authorization is not supported by this provider") @@ -81,11 +304,3 @@ impl DeviceOAuth { Ok(res?) } } - -pub trait CliOAuthProviderInfo { - fn device_code_endpoint(&self) -> Option<&str>; - fn auth_url_endpoint(&self) -> &str; - fn token_endpoint(&self) -> &str; - fn client_id(&self) -> &str; - fn scopes(&self) -> &[String]; -} diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs new file mode 100644 index 00000000..47049421 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -0,0 +1,129 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +/// A callback function that receives an incoming HTTP request and returns a response. +pub type Callback = Arc< + dyn Fn( + Request, + ) + -> std::pin::Pin>>> + Send>> + + Send + + Sync, +>; + +/// Start a minimal HTTP server on the given port that forwards every incoming +/// request to `callback` and returns whatever response the callback produces. +/// +/// The server will run until a message is sent on the `shutdown` channel, at +/// which point it will stop accepting new connections and return. +pub async fn run_proxy_server( + port: u16, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + serve_loop(listener, callback, shutdown).await +} + +/// Core accept-loop shared by [`run_proxy_server`] and tests. +/// +/// Accepts connections on `listener`, forwarding each request to `callback`. +/// Stops when `shutdown` fires. +async fn serve_loop( + listener: TcpListener, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + tokio::pin!(shutdown); + + loop { + tokio::select! { + _ = &mut shutdown => { + break; + } + accepted = listener.accept() => { + let (stream, _remote_addr) = accepted?; + let io = TokioIo::new(stream); + let cb = Arc::clone(&callback); + + tokio::task::spawn(async move { + let service = service_fn(move |req: Request| { + let cb = Arc::clone(&cb); + async move { cb(req).await } + }); + + if let Err(err) = + hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {err}"); + } + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::StatusCode; + + #[tokio::test] + async fn test_proxy_server_responds() { + let callback: Callback = Arc::new(|_req| { + Box::pin(async { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("hello from callback"))) + .unwrap()) + }) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + // Use port 0 to let the OS pick an available port. + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn({ + let callback = Arc::clone(&callback); + async move { + serve_loop(listener, callback, rx).await.unwrap(); + } + }); + + // Send a request to the server. + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "hello from callback"); + + // Shut down the server. + tx.send(()).unwrap(); + server_handle.await.unwrap(); + } +} diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs index d643965b..1a1cde05 100644 --- a/v-cli-sdk/src/cmd/config/mod.rs +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -74,7 +74,10 @@ pub enum SetCmd { } impl ConfigCmd { - pub async fn run(&self, ctx: &mut T) -> Result<()> where T: CliContext { + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + { match &self.setting { SettingCmd::Get(get) => get.run(ctx.config()).await?, SettingCmd::Set(set) => set.run(ctx.config_mut()).await?, @@ -85,10 +88,19 @@ impl ConfigCmd { } impl GetCmd { - pub async fn run(&self, config: &T) -> Result<()> where T: CliConfig { + pub async fn run(&self, config: &T) -> Result<()> + where + T: CliConfig, + { match &self { GetCmd::Format => { - println!("{}", config.default_format().map(|f| *f).unwrap_or(FormatStyle::Json)); + println!( + "{}", + config + .default_format() + .map(|f| *f) + .unwrap_or(FormatStyle::Json) + ); } GetCmd::Host => { println!("{}", config.host().unwrap_or("None")); @@ -109,7 +121,10 @@ impl GetCmd { } impl SetCmd { - pub async fn run(&self, config: &mut T) -> Result<()> where T: CliConfig { + pub async fn run(&self, config: &mut T) -> Result<()> + where + T: CliConfig, + { match &self { SetCmd::Format { format } => { config.set_default_format(format.clone()); diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs index 2c7670e4..fd22c123 100644 --- a/v-cli-sdk/src/err.rs +++ b/v-cli-sdk/src/err.rs @@ -2,12 +2,16 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -use anyhow::{Error, anyhow}; +use anyhow::{anyhow, Error}; use progenitor_client::Error as ProgenitorClientError; use crate::{ApiErrorMessage, CliContext, VerbosityLevel}; -pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error where T: CliContext, E: ApiErrorMessage { +pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error +where + T: CliContext, + E: ApiErrorMessage, +{ let mut err = anyhow!("API Request failed"); match client_err { @@ -25,14 +29,23 @@ pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) let response_message = response.into_inner(); if ctx.verbosity() >= VerbosityLevel::All { - err = err.context(format!("Request {}", response_message.request_id().as_deref().unwrap_or(""))); + err = err.context(format!( + "Request {}", + response_message.request_id().as_deref().unwrap_or("") + )); } err = err.context(format!( "Code: {}", response_message.error_code().as_deref().unwrap_or("") )); - err = err.context(response_message.message().as_deref().unwrap_or("").to_string()); + err = err.context( + response_message + .message() + .as_deref() + .unwrap_or("") + .to_string(), + ); } ProgenitorClientError::InvalidRequest(message) => { err = err.context("Invalid request").context(message); diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index b5375292..e0a8d374 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -2,7 +2,10 @@ use clap::ValueEnum; use serde::{Deserialize, Serialize}; use std::fmt::Display; -use crate::cmd::{auth::login::{CliMagicLinkAdapter, CliOAuthAdapter}, config::CliConfig}; +use crate::cmd::{ + auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}, + config::CliConfig, +}; pub mod cmd; pub mod err; @@ -41,8 +44,12 @@ pub trait CliContext { fn printer(&self) -> Option<&P>; fn verbosity(&self) -> VerbosityLevel; - fn oauth_adapter(&self) -> impl CliOAuthAdapter; - fn mlink_adapter(&self) -> impl CliMagicLinkAdapter; + fn oauth_adapter( + &self, + ) -> impl CliOAuthAdapter + Send + Sync + 'static; + fn mlink_adapter( + &self, + ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; } pub trait ApiErrorMessage { From eea958f7d143598393490f8a47cd17182f08a924 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 30 Apr 2026 10:39:18 -0500 Subject: [PATCH 04/51] Fmt and lint --- v-api/src/endpoints/handlers.rs | 5 ++-- v-api/src/endpoints/login/oauth/code.rs | 37 ++++++++++++++++++++++--- v-cli-sdk/src/cmd/mod.rs | 4 +++ v-cli-sdk/src/lib.rs | 4 +++ 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index f2dd46e4..070977bf 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -72,7 +72,7 @@ mod macros { code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, - OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, + OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, OAuthAuthzCodeExchangeQuery }, device_token::{ exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest, @@ -296,9 +296,10 @@ mod macros { pub async fn authz_code_exchange( rqctx: RequestContext<$context_type>, path: Path, + query: Query, body: TypedBody, ) -> Result, HttpError> { - authz_code_exchange_op(&rqctx, path, body).await + authz_code_exchange_op(&rqctx, path, query, body).await } // DEVICE CODE diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs index 969547f9..c3cd164e 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/code.rs @@ -436,6 +436,12 @@ where Ok(attempt.callback_url()) } +#[derive(Debug, Deserialize, JsonSchema)] +pub struct OAuthAuthzCodeExchangeQuery { + #[serde(default)] + pub include_idp_token: bool, +} + #[derive(Debug, Deserialize, JsonSchema)] pub struct OAuthAuthzCodeExchangeBody { pub client_id: Option>, @@ -451,12 +457,19 @@ pub struct OAuthAuthzCodeExchangeResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, + pub idp_token: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeIdpToken { + pub token: String, } #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_exchange_op( rqctx: &RequestContext>, path: Path, + query: Query, body: TypedBody, ) -> Result, HttpError> where @@ -464,6 +477,7 @@ where { let ctx = rqctx.v_ctx(); let path = path.into_inner(); + let query = query.into_inner(); let body = body.into_inner(); let (client_id, client_secret) = @@ -537,7 +551,14 @@ where // Now that the attempt has been confirmed, use it to fetch user information form the remote // provider - let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?; + let (info, raw_token) = fetch_user_info( + ctx.public_url(), + &ctx.web_client(), + &*provider, + &attempt, + query.include_idp_token, + ) + .await?; tracing::debug!("Retrieved user information from remote provider"); @@ -585,6 +606,9 @@ where token_type: "Bearer".to_string(), access_token: token.signed_token, expires_in: token.expires_in, + idp_token: query.include_idp_token.then(|| OAuthAuthzCodeIdpToken { + token: raw_token.unwrap(), + }), })) } @@ -709,7 +733,8 @@ async fn fetch_user_info( client_type: &ClientType, provider: &dyn OAuthProvider, attempt: &LoginAttempt, -) -> Result { + return_raw: bool, +) -> Result<(UserInfo, Option), HttpError> { // Exchange the stored authorization code with the remote provider for a remote access token let client = provider.as_web_client().map_err(to_internal_error)?; @@ -746,7 +771,7 @@ async fn fetch_user_info( // Now that we are done with fetching user information from the remote API, we can revoke it if // the provider supports it - if provider.token_revocation_endpoint().is_some() { + if !return_raw && provider.token_revocation_endpoint().is_some() { client .revoke_token(response.access_token().into()) .map_err(internal_error)? @@ -755,7 +780,11 @@ async fn fetch_user_info( .map_err(internal_error)?; } - Ok(info) + if return_raw { + Ok((info, Some(response.access_token().secret().to_string()))) + } else { + Ok((info, None)) + } } #[cfg(test)] diff --git a/v-cli-sdk/src/cmd/mod.rs b/v-cli-sdk/src/cmd/mod.rs index c73c20da..df67c4a3 100644 --- a/v-cli-sdk/src/cmd/mod.rs +++ b/v-cli-sdk/src/cmd/mod.rs @@ -1,2 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + pub mod auth; pub mod config; diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index e0a8d374..652ab27c 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -1,3 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + use clap::ValueEnum; use serde::{Deserialize, Serialize}; use std::fmt::Display; From 3ada182f6cf9f5d8917327f5a98b54382317a2ad Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 30 Apr 2026 10:43:36 -0500 Subject: [PATCH 05/51] Lint and fmt --- v-cli-sdk/src/cmd/config/mod.rs | 4 ++-- v-cli-sdk/src/err.rs | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs index 1a1cde05..bf2d123e 100644 --- a/v-cli-sdk/src/cmd/config/mod.rs +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -98,7 +98,7 @@ impl GetCmd { "{}", config .default_format() - .map(|f| *f) + .copied() .unwrap_or(FormatStyle::Json) ); } @@ -127,7 +127,7 @@ impl SetCmd { { match &self { SetCmd::Format { format } => { - config.set_default_format(format.clone()); + config.set_default_format(*format); config.save()?; } SetCmd::Host { host } => { diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs index fd22c123..e84eeb36 100644 --- a/v-cli-sdk/src/err.rs +++ b/v-cli-sdk/src/err.rs @@ -31,21 +31,15 @@ where if ctx.verbosity() >= VerbosityLevel::All { err = err.context(format!( "Request {}", - response_message.request_id().as_deref().unwrap_or("") + response_message.request_id().unwrap_or("") )); } err = err.context(format!( "Code: {}", - response_message.error_code().as_deref().unwrap_or("") + response_message.error_code().unwrap_or("") )); - err = err.context( - response_message - .message() - .as_deref() - .unwrap_or("") - .to_string(), - ); + err = err.context(response_message.message().unwrap_or("").to_string()); } ProgenitorClientError::InvalidRequest(message) => { err = err.context("Invalid request").context(message); From cd8d1eedb97644c85a83e3d036b36a040e91afa1 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Tue, 5 May 2026 14:08:24 -0500 Subject: [PATCH 06/51] A lot of refactoring of OAuth internals --- Cargo.lock | 25 ++ Cargo.toml | 2 + v-api/src/config.rs | 77 +++- v-api/src/context/mod.rs | 19 +- v-api/src/endpoints/handlers.rs | 19 +- v-api/src/endpoints/login/oauth/code.rs | 271 ++++++++++---- .../src/endpoints/login/oauth/device_token.rs | 336 +++++++++--------- v-api/src/endpoints/login/oauth/github.rs | 139 +++----- v-api/src/endpoints/login/oauth/google.rs | 157 ++++---- v-api/src/endpoints/login/oauth/mod.rs | 187 +++++++--- v-api/src/endpoints/login/oauth/zendesk.rs | 190 +++++----- v-api/src/secrets.rs | 2 +- v-cli-sdk/Cargo.toml | 4 + v-cli-sdk/src/cmd/auth/login.rs | 43 ++- v-cli-sdk/src/cmd/auth/mod.rs | 8 +- v-cli-sdk/src/cmd/auth/oauth.rs | 54 ++- v-cli-sdk/src/cmd/config/mod.rs | 20 +- v-cli-sdk/src/err.rs | 6 +- v-cli-sdk/src/lib.rs | 27 +- v-cli-sdk/src/printer/mod.rs | 227 ++++++++++++ 20 files changed, 1142 insertions(+), 671 deletions(-) create mode 100644 v-cli-sdk/src/printer/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 5a3f6cf2..2a07fcd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2043,6 +2043,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "owo-colors" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" + [[package]] name = "p256" version = "0.13.2" @@ -3347,6 +3353,15 @@ dependencies = [ "libc", ] +[[package]] +name = "tabwriter" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fce91f2f0ec87dff7e6bcbbeb267439aa1188703003c6055193c821487400432" +dependencies = [ + "unicode-width", +] + [[package]] name = "take_mut" version = "0.2.2" @@ -3737,6 +3752,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -3887,9 +3908,13 @@ dependencies = [ "hyper-util", "oauth2", "oauth2-reqwest", + "owo-colors", "progenitor-client", "reqwest 0.13.2", + "schemars 0.8.22", "serde", + "serde_json", + "tabwriter", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index e2353bad..2b476f4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ mockall = "0.14.0" newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] } oauth2 = { version = "5.0.0", default-features = false, features = ["rustls-tls"] } oauth2-reqwest = "0.1.0-alpha.3" +owo-colors = "4.2.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } percent-encoding = "2.3.2" proc-macro2 = "1" @@ -56,6 +57,7 @@ sha2 = "0.11.0" slog = "2.8.2" steno = { git = "https://github.com/oxidecomputer/steno" } syn = "2" +tabwriter = "1.4.1" tap = "1.0.1" tempfile = "3" thiserror = "2" diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 917063ce..e56a07fc 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -8,20 +8,21 @@ use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, Jwk, KeyAlgorithm, PublicKeyUse, RSAKeyParameters, RSAKeyType, }; +use partial_struct::partial; use rsa::{ pkcs1v15::{SigningKey, VerifyingKey}, pkcs8::{DecodePrivateKey, DecodePublicKey}, traits::PublicKeyParts, RsaPrivateKey, RsaPublicKey, }; -use secrecy::ExposeSecret; +use secrecy::{ExposeSecret, SecretString}; use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; use std::path::PathBuf; use thiserror::Error; -use v_api_param::StringParam; +use v_api_param::{ParamResolutionError, StringParam}; use crate::{ authn::{ @@ -154,25 +155,93 @@ pub struct OAuthProviders { pub zendesk: Option, } +#[partial(ResolvedOAuthConfig)] #[derive(Debug, Deserialize)] pub struct OAuthConfig { - pub device: OAuthDeviceConfig, - pub web: OAuthWebConfig, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub device: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub web: Option, + #[partial(ResolvedOAuthConfig(retype = Option))] + pub proxy_web: Option, } +#[partial(ResolvedOAuthDeviceConfig)] #[derive(Debug, Deserialize)] pub struct OAuthDeviceConfig { pub client_id: String, + #[partial(ResolvedOAuthDeviceConfig(retype = SecretString))] pub client_secret: StringParam, } +#[partial(ResolvedOAuthWebConfig)] #[derive(Debug, Deserialize)] pub struct OAuthWebConfig { pub client_id: String, + #[partial(ResolvedOAuthWebConfig(retype = SecretString))] pub client_secret: StringParam, pub redirect_uri: String, } +#[partial(ResolvedOAuthWebProxyConfig)] +#[derive(Debug, Deserialize)] +pub struct OAuthWebProxyConfig { + pub client_id: String, + pub proxy_port: u16, +} + +impl OAuthConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let device = self.device.as_ref().and_then(|d| d.resolve(base).ok()); + let web = self.web.as_ref().and_then(|w| w.resolve(base).ok()); + let proxy_web = self.proxy_web.as_ref().and_then(|p| p.resolve(base).ok()); + Ok(ResolvedOAuthConfig { + device, + web, + proxy_web, + }) + } +} +impl OAuthDeviceConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let client_secret = self.client_secret.resolve(base)?; + Ok(ResolvedOAuthDeviceConfig { + client_id: self.client_id.clone(), + client_secret, + }) + } +} +impl OAuthWebConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + let client_secret = self.client_secret.resolve(base)?; + Ok(ResolvedOAuthWebConfig { + client_id: self.client_id.clone(), + client_secret, + redirect_uri: self.redirect_uri.clone(), + }) + } +} +impl OAuthWebProxyConfig { + pub fn resolve( + &self, + base: Option, + ) -> Result { + Ok(ResolvedOAuthWebProxyConfig { + client_id: self.client_id.clone(), + proxy_port: self.proxy_port, + }) + } +} + impl AsymmetricKey { pub fn resolve_signer(&self, path: Option) -> Result { Ok(Signer::new( diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index f0dbd40c..5672eab4 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1246,7 +1246,9 @@ pub(crate) mod test_mocks { use crate::{ config::JwtConfig, - endpoints::login::oauth::{google::GoogleOAuthProvider, OAuthProviderName}, + endpoints::login::oauth::{ + google::GoogleOAuthProvider, zendesk::ZendeskOAuthProvider, OAuthProviderName, + }, mapper::DefaultMappingEngine, permissions::VPermission, util::tests::{mock_key, MockKey}, @@ -1288,6 +1290,21 @@ pub(crate) mod test_mocks { }), ); + ctx.auth.insert_oauth_provider( + OAuthProviderName::Zendesk, + Box::new(move || { + Box::new(ZendeskOAuthProvider::new( + "https://test_public_url".to_string(), + "subdomain".to_string(), + "google_device_client_id".to_string(), + "google_device_client_secret".to_string().into(), + "google_web_client_id".to_string(), + "google_web_client_secret".to_string().into(), + None, + )) + }), + ); + ctx } diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index 070977bf..f5de62bb 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -71,6 +71,7 @@ mod macros { }, code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, + get_web_pkce_provider_op, OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, OAuthAuthzCodeExchangeQuery }, @@ -304,7 +305,19 @@ mod macros { // DEVICE CODE - /// Retrieve the metadata about an OAuth provider + /// Retrieve the metadata about an OAuth provider for public authorization code flow + #[endpoint { + method = GET, + path = "/login/oauth/{provider}/web-pkce" + }] + pub async fn get_web_pkce_provider( + rqctx: RequestContext<$context_type>, + path: Path, + ) -> Result, HttpError> { + get_web_pkce_provider_op(&rqctx, path).await + } + + /// Retrieve the metadata about an OAuth provider for limited input flow #[endpoint { method = GET, path = "/login/oauth/{provider}/device" @@ -757,7 +770,9 @@ mod macros { $api.register(authz_code_exchange) .expect("Failed to register endpoint"); - // OAuth Device Login + // OAuth Login + $api.register(get_web_provider) + .expect("Failed to register endpoint"); $api.register(get_device_provider) .expect("Failed to register endpoint"); $api.register(exchange_device_token) diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs index c3cd164e..1528f495 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/code.rs @@ -12,7 +12,7 @@ use dropshot::{ }; use dropshot_authorization_header::basic::BasicAuth; use http::{header::SET_COOKIE, HeaderValue}; -use newtype_uuid::TypedUuid; +use newtype_uuid::{GenericUuid, TypedUuid}; use oauth2::{ AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, }; @@ -24,6 +24,7 @@ use sha2::{Digest, Sha256}; use std::{fmt::Debug, ops::Add}; use tap::TapFallible; use tracing::instrument; +use uuid::Uuid; use v_model::{ permissions::{AsScope, PermissionStorage}, schema_ext::LoginAttemptState, @@ -35,11 +36,12 @@ use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, endpoints::login::{ - oauth::{CheckOAuthClient, ClientType}, + oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, LoginError, UserInfo, }, error::ApiError, permissions::{VAppPermission, VPermission}, + response::bad_request, secrets::OpenApiSecretString, util::{ request::RequestCookies, @@ -152,6 +154,32 @@ where } } +#[instrument(skip(rqctx), err(Debug))] +pub async fn get_web_pkce_provider_op( + rqctx: &RequestContext>, + path: Path, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let path = path.into_inner(); + + tracing::trace!("Getting OAuth data for {}", path.provider); + + let provider = rqctx + .v_ctx() + .get_oauth_provider(&path.provider) + .await + .map_err(ApiError::OAuth)?; + + Ok(HttpResponseOk( + provider + .authz_code_pkce_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support web pkce clients"))?, + )) +} + #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_redirect_op( rqctx: &RequestContext>, @@ -256,7 +284,7 @@ fn oauth_redirect_response( .authorize_url(|| CsrfToken::new(attempt.id.to_string())) .add_scopes( provider - .scopes() + .default_scopes() .into_iter() .map(|s| Scope::new(s.to_string())) .collect::>(), @@ -306,7 +334,7 @@ fn verify_csrf( .value() .parse() .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); + tracing::warn!(?err, "Failed to parse state cookie"); unauthorized() })?; @@ -480,46 +508,71 @@ where let query = query.into_inner(); let body = body.into_inner(); - let (client_id, client_secret) = - if let (Some(client_id), Some(client_secret)) = (body.client_id, body.client_secret) { - Ok::<_, HttpError>((client_id, client_secret)) - } else { - // Attempt to extract basic authorization credentials from the request if they were not - // present in the request body - let auth = ::from_request(rqctx) - .await - .tap_err(|err| { - tracing::warn!(?err, "Failed to extract basic authentication values"); - }); - let (client_id, client_secret) = match auth { - Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(( - auth.username().unwrap().to_string(), - auth.password().unwrap().to_string(), - )), - _ => Err(internal_error( - "Missing client id and client secret from authz code exchange", - )), - }?; - - Ok(( - client_id.parse().map_err(to_internal_error)?, - OpenApiSecretString(client_secret.into()), - )) - }?; - let provider = ctx .get_oauth_provider(&path.provider) .await .map_err(ApiError::OAuth)?; + // Extract basic authorization credentials from the request if they were provided. + let auth = ::from_request(rqctx) + .await + .tap_err(|err| { + tracing::warn!(?err, "Failed to extract basic authentication values"); + }); + let basic_credentials = match auth { + Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(Some(( + TypedUuid::from_untyped_uuid( + Uuid::parse_str(auth.username().unwrap()) + .map_err(|_| bad_request("Malformed client ID presented to code exchange"))?, + ), + auth.password().unwrap().to_string(), + ))), + Ok(_) => Err(bad_request( + "Malformed credentials presented to code exchange", + )), + Err(err) => { + tracing::info!( + ?err, + "Credentials for code exchange not defined via basic auth" + ); + Ok(None) + } + }?; + + // Extract credentials from the request body if they were provided. + let body_credentials = (body.client_id, body.client_secret); + + // Now validate if the credentials provided by the client support one of our expected schemes. + // We of course deny underspecifying credentials, but we also want to disallow over specifying + // them. For example, if the client provides both basic auth and a client id/secret in the + // request body, we should reject the request. + let (client_id, client_secret) = match (basic_credentials, body_credentials) { + (Some(_), (Some(_), _)) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some(_), (_, Some(_))) => Err(bad_request( + "Cannot provide both basic auth and client credentials", + )), + (Some((client_id, client_secret)), (None, None)) => Ok(( + client_id, + Some(OpenApiSecretString(SecretString::from(client_secret))), + )), + (None, (Some(client_id), Some(client_secret))) => Ok((client_id, Some(client_secret))), + (None, (Some(client_id), _)) if provider.authz_code_pkce_flow_info().is_some() => { + Ok((client_id, None)) + } + _ => Err(bad_request("Missing client credentials")), + }?; + tracing::debug!("Attempting code exchange"); // Verify the submitted client credentials authorize_code_exchange( ctx, + &*provider, &body.grant_type, client_id, - &client_secret.0, + client_secret.map(|s| s.0).as_ref(), &body.redirect_uri, ) .await?; @@ -614,9 +667,10 @@ where async fn authorize_code_exchange( ctx: &VContext, + provider: &dyn OAuthProvider, grant_type: &str, client_id: TypedUuid, - client_secret: &SecretString, + client_secret: Option<&SecretString>, redirect_uri: &str, ) -> Result<(), OAuthError> where @@ -636,30 +690,44 @@ where tracing::debug!(grant_type, "Verified grant type"); - let client_secret = RawKey::try_from(client_secret).map_err(|err| { - tracing::warn!(?err, "Failed to parse OAuth client secret"); + // If we were provided a client secret, then it must be verified. If a client secret was not + // provided, then we can skip this step as long as the provider supports pkce_only + // authentication. + if let Some(client_secret) = client_secret { + let client_secret = RawKey::try_from(client_secret).map_err(|err| { + tracing::warn!(?err, "Failed to parse OAuth client secret"); - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Malformed client secret".to_string()), - error_uri: None, - state: None, - } - })?; + OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Malformed client secret".to_string()), + error_uri: None, + state: None, + } + })?; - tracing::debug!("Constructed client secret"); + tracing::debug!("Constructed client secret"); - if !client.is_secret_valid(&client_secret, ctx) { + if !client.is_secret_valid(&client_secret, ctx) { + Err(OAuthError { + error: OAuthErrorCode::InvalidClient, + error_description: Some("Invalid client secret".to_string()), + error_uri: None, + state: None, + }) + } else { + tracing::debug!("Verified client secret validity"); + + Ok(()) + } + } else if provider.authz_code_pkce_flow_info().is_some() { + Ok(()) + } else { Err(OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Invalid client secret".to_string()), + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Client secret required".to_string()), error_uri: None, state: None, }) - } else { - tracing::debug!("Verified client secret validity"); - - Ok(()) } } @@ -735,6 +803,9 @@ async fn fetch_user_info( attempt: &LoginAttempt, return_raw: bool, ) -> Result<(UserInfo, Option), HttpError> { + let provider_info = provider + .authz_code_flow_info() + .ok_or_else(|| internal_error("Authorization code flow not supported"))?; // Exchange the stored authorization code with the remote provider for a remote access token let client = provider.as_web_client().map_err(to_internal_error)?; @@ -771,7 +842,7 @@ async fn fetch_user_info( // Now that we are done with fetching user information from the remote API, we can revoke it if // the provider supports it - if !return_raw && provider.token_revocation_endpoint().is_some() { + if !return_raw && provider_info.revocation_endpoint.is_some() { client .revoke_token(response.access_token().into()) .map_err(internal_error)? @@ -1295,15 +1366,20 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); // 1. Verify exchange fails when passing an incorrect client id assert_eq!( Some("Unknown client id".to_string()), authorize_code_exchange( &ctx, + &*provider, "authorization_code", wrong_client_id, - &client_secret, + Some(&client_secret), &redirect_uri, ) .await @@ -1316,9 +1392,10 @@ mod tests { Some("Invalid redirect uri".to_string()), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), "wrong-callback-destination", ) .await @@ -1326,14 +1403,75 @@ mod tests { .error_description ); - // 3. Verify a successful exchange + // 3. Verify a successful exchange with a client secret assert_eq!( (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), + &redirect_uri, + ) + .await + .unwrap() + ); + } + + #[tokio::test] + async fn test_exchange_requires_secret_except_for_pkce_only() { + let (mut ctx, client, _) = mock_client().await; + let client_id = client.id; + let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); + + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| Ok(Some(client.clone()))); + + let mut storage = MockStorage::new(); + storage.oauth_client_store = Some(Arc::new(client_store)); + + ctx.set_storage(Arc::new(storage)); + + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); + let pkce_only_provider = ctx + .get_oauth_provider(&OAuthProviderName::Zendesk) + .await + .unwrap(); + + // 1. Verify exchange fails when not passing a client secret for a client that does not + // support pkce_only + assert_eq!( + Some("Client secret required".to_string()), + authorize_code_exchange( + &ctx, + &*provider, + "authorization_code", + client_id, + None, + &redirect_uri, + ) + .await + .unwrap_err() + .error_description + ); + + // 2. Verify exchange passes when omitting the client secret for a client that does + // support pkce_only + assert_eq!( + (), + authorize_code_exchange( + &ctx, + &*pkce_only_provider, + "authorization_code", + client_id, + None, &redirect_uri, ) .await @@ -1357,14 +1495,19 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); assert_eq!( OAuthErrorCode::UnsupportedGrantType, authorize_code_exchange( &ctx, + &*provider, "not_authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await @@ -1376,9 +1519,10 @@ mod tests { (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await @@ -1402,6 +1546,10 @@ mod tests { storage.oauth_client_store = Some(Arc::new(client_store)); ctx.set_storage(Arc::new(storage)); + let provider = ctx + .get_oauth_provider(&OAuthProviderName::Google) + .await + .unwrap(); let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) .sign(&*ctx.signer()) @@ -1414,9 +1562,10 @@ mod tests { OAuthErrorCode::InvalidRequest, authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &"too-short".to_string().into(), + Some(&"too-short".to_string().into()), &redirect_uri ) .await @@ -1428,9 +1577,10 @@ mod tests { OAuthErrorCode::InvalidClient, authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &invalid_secret.into(), + Some(&invalid_secret.into()), &redirect_uri ) .await @@ -1442,9 +1592,10 @@ mod tests { (), authorize_code_exchange( &ctx, + &*provider, "authorization_code", client_id, - &client_secret, + Some(&client_secret), &redirect_uri ) .await diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs index 22488a61..77530ea8 100644 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ b/v-api/src/endpoints/login/oauth/device_token.rs @@ -14,19 +14,21 @@ use tap::TapFallible; use tracing::instrument; use v_model::permissions::PermissionStorage; -use super::{ - ClientType, OAuthProvider, OAuthProviderInfo, OAuthProviderNameParam, UserInfoProvider, -}; +use super::{OAuthProviderNameParam, UserInfoProvider}; use crate::{ - context::ApiContext, endpoints::login::LoginError, error::ApiError, - permissions::VAppPermission, response::internal_error, util::response::bad_request, + context::ApiContext, + endpoints::login::{oauth::OAuthProviderDeviceInfo, LoginError}, + error::ApiError, + permissions::VAppPermission, + response::internal_error, + util::response::bad_request, }; #[instrument(skip(rqctx), err(Debug))] pub async fn get_device_provider_op( rqctx: &RequestContext>, path: Path, -) -> Result, HttpError> +) -> Result, HttpError> where T: VAppPermission + PermissionStorage, { @@ -40,10 +42,15 @@ where .await .map_err(ApiError::OAuth)?; - Ok(HttpResponseOk(provider.provider_info(&ClientType::Device))) + Ok(HttpResponseOk( + provider + .device_code_flow_info() + .cloned() + .ok_or_else(|| bad_request("Provider does not support device clients"))?, + )) } -#[derive(Debug, Deserialize, JsonSchema, Serialize)] +#[derive(Debug, Deserialize, Serialize, JsonSchema)] pub struct AccessTokenExchangeRequest { pub device_code: String, pub grant_type: String, @@ -65,21 +72,16 @@ pub struct ProviderTokenExchange { } impl AccessTokenExchange { - pub fn new( - req: AccessTokenExchangeRequest, - provider: &(dyn OAuthProvider + Send + Sync), - ) -> Option { - provider - .client_secret(&ClientType::Device) - .map(|client_secret| Self { - provider: ProviderTokenExchange { - client_id: provider.client_id(&ClientType::Device).to_string(), - device_code: req.device_code, - grant_type: req.grant_type, - client_secret: client_secret.expose_secret().to_string(), - }, - expires_at: req.expires_at, - }) + pub fn new(req: AccessTokenExchangeRequest, provider: &OAuthProviderDeviceInfo) -> Self { + Self { + provider: ProviderTokenExchange { + client_id: provider.client_id.clone(), + device_code: req.device_code, + grant_type: req.grant_type, + client_secret: provider.client_secret.0.expose_secret().to_string(), + }, + expires_at: req.expires_at, + } } } @@ -118,10 +120,11 @@ where .get_oauth_provider(&path.provider) .await .map_err(ApiError::OAuth)?; + let device_info = provider.device_code_flow_info(); tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); - if provider.device_code_endpoint().is_none() { + if device_info.is_none() { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST) .header(header::CONTENT_TYPE, "application/json") @@ -139,156 +142,151 @@ where )?); } + let device_info = device_info.unwrap(); let exchange_request = body.into_inner(); - - if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { - let token_exchange_endpoint = provider.token_exchange_endpoint(); - let client = reqwest::Client::new(); - - let response = client - .request(Method::POST, token_exchange_endpoint) - .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) - .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .body( - // We know that this is safe to unwrap as we just deserialized it via the body Extractor - serde_urlencoded::to_string(&exchange.provider).unwrap(), - ) - .send() - .await - .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) - .map_err(internal_error)?; - - // Take a part the response as we will need the individual parts later - let status = response.status(); - let headers = response.headers().clone(); - let bytes = response.bytes().await.map_err(internal_error)?; - - // We unfortunately can not trust our providers to follow specs and therefore need to do - // our own inspection of the response to determine what to do - if !status.is_success() { - // If the server returned a non-success status then we are going to trust the server and - // report their error back to the client - tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - Ok(client_response) - } else { - // The server gave us back a non-error response but it still may not be a success. - // GitHub for instance does not use a status code for indicating the success or failure - // of a call. So instead we try to deserialize the body into an access token, with the - // understanding that it may fail and we will need to try and treat the response as - // an error instead. - - let parsed: Result< - StandardTokenResponse, - serde_json::Error, - > = serde_json::from_slice(&bytes); - - match parsed { - Ok(parsed) => { - let info = provider - .get_user_info(parsed.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| { - tracing::error!(?err, "Failed to look up user information") - })?; - - tracing::debug!("Verified and validated OAuth user"); - - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); - - let claims = - ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); - let token = ctx - .user - .register_access_token( - &ctx.builtin_registration_user(), - ctx.jwt_signer(), - &api_user_info.user.id, - &claims, - ) - .await?; - - tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/json") - .body( - serde_json::to_string(&ProxyTokenResponse { - access_token: token.signed_token, - token_type: "Bearer".to_string(), - expires_in: Some(claims.exp - Utc::now().timestamp()), - refresh_token: None, - scopes: None, + let exchange = AccessTokenExchange::new(exchange_request, &device_info); + + let client = reqwest::Client::new(); + + let response = client + .request(Method::POST, &device_info.token_endpoint) + .header( + header::CONTENT_TYPE, + &device_info.token_endpoint_content_type, + ) + .header(header::ACCEPT, HeaderValue::from_static("application/json")) + .body( + // We know that this is safe to unwrap as we just deserialized it via the body Extractor + serde_urlencoded::to_string(&exchange.provider).unwrap(), + ) + .send() + .await + .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) + .map_err(internal_error)?; + + // Take a part the response as we will need the individual parts later + let status = response.status(); + let headers = response.headers().clone(); + let bytes = response.bytes().await.map_err(internal_error)?; + + // We unfortunately can not trust our providers to follow specs and therefore need to do + // our own inspection of the response to determine what to do + if !status.is_success() { + // If the server returned a non-success status then we are going to trust the server and + // report their error back to the client + tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); + + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = headers; + *client_response.status_mut() = status; + + Ok(client_response) + } else { + // The server gave us back a non-error response but it still may not be a success. + // GitHub for instance does not use a status code for indicating the success or failure + // of a call. So instead we try to deserialize the body into an access token, with the + // understanding that it may fail and we will need to try and treat the response as + // an error instead. + + let parsed: Result< + StandardTokenResponse, + serde_json::Error, + > = serde_json::from_slice(&bytes); + + match parsed { + Ok(parsed) => { + let info = provider + .get_user_info(parsed.access_token().secret()) + .await + .map_err(LoginError::UserInfo) + .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; + + tracing::debug!("Verified and validated OAuth user"); + + let (api_user_info, api_user_provider) = ctx + .register_api_user(&ctx.builtin_registration_user(), info) + .await?; + + tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); + + let claims = + ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); + let token = ctx + .user + .register_access_token( + &ctx.builtin_registration_user(), + ctx.jwt_signer(), + &api_user_info.user.id, + &claims, + ) + .await?; + + tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_string(&ProxyTokenResponse { + access_token: token.signed_token, + token_type: "Bearer".to_string(), + expires_in: Some(claims.exp - Utc::now().timestamp()), + refresh_token: None, + scopes: None, + }) + .unwrap() + .into(), + )?) + } + Err(_) => { + // Do not log the error here as we want to ensure we do not leak token information + tracing::debug!( + "Failed to parse a success response from the remote token endpoint" + ); + + // Try to deserialize the body again, but this time as an error + let mut error_response = match serde_json::from_slice::(&bytes) { + Ok(error) => { + // We found an error in the message body. This is not ideal, but we at + // least can understand what the server was trying to tell us + tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); + + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = headers; + *client_response.status_mut() = status; + + client_response + } + Err(_) => { + // We still do not know what the remote server is doing... and need to + // cancel the request ourselves + tracing::warn!( + "Remote OAuth provide returned a response that we do not undestand" + ); + + Response::new( + serde_json::to_vec(&ProxyTokenError { + error: "access_denied".to_string(), + error_description: Some(format!( + "{} returned a malformed response", + path.provider + )), + error_uri: None, }) .unwrap() .into(), - )?) - } - Err(_) => { - // Do not log the error here as we want to ensure we do not leak token information - tracing::debug!( - "Failed to parse a success response from the remote token endpoint" - ); - - // Try to deserialize the body again, but this time as an error - let mut error_response = match serde_json::from_slice::(&bytes) - { - Ok(error) => { - // We found an error in the message body. This is not ideal, but we at - // least can understand what the server was trying to tell us - tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - client_response - } - Err(_) => { - // We still do not know what the remote server is doing... and need to - // cancel the request ourselves - tracing::warn!( - "Remote OAuth provide returned a response that we do not undestand" - ); - - Response::new( - serde_json::to_vec(&ProxyTokenError { - error: "access_denied".to_string(), - error_description: Some(format!( - "{} returned a malformed response", - path.provider - )), - error_uri: None, - }) - .unwrap() - .into(), - ) - } - }; - - *error_response.status_mut() = StatusCode::BAD_REQUEST; - error_response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - Ok(error_response) - } + ) + } + }; + + *error_response.status_mut() = StatusCode::BAD_REQUEST; + error_response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + Ok(error_response) } } - } else { - tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); - - Err(bad_request("Invalid provider")) } } diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/github.rs index 3f3483fe..9b7de3e9 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/github.rs @@ -5,30 +5,28 @@ use http::{header::USER_AGENT, HeaderMap, HeaderValue}; use hyper::body::Bytes; use reqwest::Request; -use secrecy::SecretString; use serde::Deserialize; use std::fmt; -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, }; +use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + pub struct GitHubOAuthProvider { - // public: GitHubPublicProvider, - // private: Option, - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, + authz_code_flow_info: Option, + device_code_flow_info: Option, default_headers: HeaderMap, + default_scopes: Vec, client: reqwest::Client, - token_endpoint: Option, - redirect_endpoint: Option, - redirect_proxy_endpoint: Option, } impl fmt::Debug for GitHubOAuthProvider { @@ -39,38 +37,43 @@ impl fmt::Debug for GitHubOAuthProvider { impl GitHubOAuthProvider { pub fn new( + config: ResolvedOAuthConfig, public_url: String, - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, additional_scopes: Option>, ) -> Self { let mut headers = HeaderMap::new(); headers.insert(USER_AGENT, HeaderValue::from_static("v-api")); + let mut default_scopes = vec!["user:email".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + client_id: web.client_id, + client_secret: web.client_secret.into(), + auth_url_endpoint: "https://github.com/login/oauth/authorize".to_string(), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + client_secret: device.client_secret.into(), + device_code_endpoint: "https://github.com/login/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }); + Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, - }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), + authz_code_flow_info, + device_code_flow_info, default_headers: headers, + default_scopes, client: reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), - token_endpoint: Some(format!("{}/login/oauth/github/device/exchange", public_url)), - redirect_endpoint: Some(format!("{}/login/oauth/github/code/callback", public_url,)), - redirect_proxy_endpoint: None, } } @@ -119,79 +122,29 @@ impl OAuthProvider for GitHubOAuthProvider { fn name(&self) -> OAuthProviderName { OAuthProviderName::GitHub } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["user:email"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - fn initialize_headers(&self, request: &mut Request) { *request.headers_mut() = self.default_headers.clone(); } - fn client(&self) -> &reqwest::Client { &self.client } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - fn user_info_endpoints(&self) -> Vec<&str> { vec![ "https://api.github.com/user", "https://api.github.com/user/emails", ] } - - fn device_code_endpoint(&self) -> Option<&str> { - Some("https://github.com/login/device/code") - } - - fn auth_url_endpoint(&self) -> &str { - "https://github.com/login/oauth/authorize" - } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" + fn default_scopes(&self) -> &[String] { + &self.default_scopes } - fn token_exchange_endpoint(&self) -> &str { - "https://github.com/login/oauth/access_token" + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() } - - fn token_revocation_endpoint(&self) -> Option<&str> { + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { None } - - fn supports_pkce(&self) -> bool { - true - } - - fn token_endpoint(&self) -> Option<&str> { - self.token_endpoint.as_deref() - } - fn redirect_endpoint(&self) -> Option<&str> { - self.redirect_endpoint.as_deref() - } - fn redirect_proxy_endpoint(&self) -> Option<&str> { - self.redirect_proxy_endpoint.as_deref() + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() } } diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs index e477ae44..666815e4 100644 --- a/v-api/src/endpoints/login/oauth/google.rs +++ b/v-api/src/endpoints/login/oauth/google.rs @@ -4,27 +4,28 @@ use hyper::body::Bytes; use reqwest::Request; -use secrecy::SecretString; use serde::Deserialize; use std::fmt; -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, }; +use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + pub struct GoogleOAuthProvider { - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, + device_code_flow_info: Option, + default_scopes: Vec, client: reqwest::Client, - token_endpoint: Option, - redirect_endpoint: Option, - redirect_proxy_endpoint: Option, } impl fmt::Debug for GoogleOAuthProvider { @@ -35,34 +36,56 @@ impl fmt::Debug for GoogleOAuthProvider { impl GoogleOAuthProvider { pub fn new( + config: ResolvedOAuthConfig, public_url: String, - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, additional_scopes: Option>, ) -> Self { + let mut default_scopes = vec![ + "openid".to_string(), + "email".to_string(), + "profile".to_string(), + ]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + client_id: web.client_id, + client_secret: web.client_secret.into(), + auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: None, + }); + let authz_code_pkce_flow_info = + config + .proxy_web + .map(|proxy| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + proxy_port: proxy.proxy_port, + revocation_endpoint: None, + }); + let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { + client_id: device.client_id, + client_secret: device.client_secret.into(), + device_code_endpoint: "https://github.com/login/device/code".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }); + Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, - }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), + authz_code_flow_info, + authz_code_pkce_flow_info, + device_code_flow_info, + default_scopes, client: reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), - token_endpoint: Some(format!("{}/login/oauth/google/device/exchange", public_url)), - redirect_endpoint: Some(format!("{}/login/oauth/google/code/callback", public_url,)), - redirect_proxy_endpoint: None, } } @@ -129,77 +152,27 @@ impl OAuthProvider for GoogleOAuthProvider { fn name(&self) -> OAuthProviderName { OAuthProviderName::Google } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["openid", "email", "profile"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - fn initialize_headers(&self, _request: &mut Request) {} - fn client(&self) -> &reqwest::Client { &self.client } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - fn user_info_endpoints(&self) -> Vec<&str> { vec![ "https://openidconnect.googleapis.com/v1/userinfo", "https://people.googleapis.com/v1/people/me?personFields=names", ] } - - fn device_code_endpoint(&self) -> Option<&str> { - Some("https://oauth2.googleapis.com/device/code") - } - - fn auth_url_endpoint(&self) -> &str { - "https://accounts.google.com/o/oauth2/v2/auth" - } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" - } - - fn token_exchange_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/token" - } - - fn token_revocation_endpoint(&self) -> Option<&str> { - Some("https://oauth2.googleapis.com/revoke") - } - - fn supports_pkce(&self) -> bool { - true + fn default_scopes(&self) -> &[String] { + &self.default_scopes } - fn token_endpoint(&self) -> Option<&str> { - self.token_endpoint.as_deref() + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() } - fn redirect_endpoint(&self) -> Option<&str> { - self.redirect_endpoint.as_deref() + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + None } - fn redirect_proxy_endpoint(&self) -> Option<&str> { - self.redirect_proxy_endpoint.as_deref() + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { + self.device_code_flow_info.as_ref() } } diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index b986c5c4..cbaeb1c6 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -18,7 +18,10 @@ use thiserror::Error; use tracing::instrument; use v_model::OAuthClient; -use crate::authn::{key::RawKey, Verify}; +use crate::{ + authn::{key::RawKey, Verify}, + secrets::OpenApiSecretString, +}; use super::{UserInfo, UserInfoError, UserInfoProvider}; @@ -37,12 +40,15 @@ pub enum OAuthProviderError { MissingRedirectUri, #[error("Failed to parse URL")] UrlParseError(#[from] ParseError), + #[error("Provider does not support web clients")] + WebClientNotSupported, } #[derive(Debug)] pub enum ClientType { Device, Web, + WebPkce, } pub type WebClient = BasicClient< @@ -68,63 +74,79 @@ pub struct OAuthPrivateCredentials { pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn name(&self) -> OAuthProviderName; - fn scopes(&self) -> Vec<&str>; fn initialize_headers(&self, request: &mut Request); fn client(&self) -> &reqwest::Client; - fn client_id(&self, client_type: &ClientType) -> &str; - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString>; - - // TODO: How can user info be change to something statically checked instead of a runtime check fn user_info_endpoints(&self) -> Vec<&str>; - fn device_code_endpoint(&self) -> Option<&str>; - fn auth_url_endpoint(&self) -> &str; - fn token_exchange_content_type(&self) -> &str; - fn token_exchange_endpoint(&self) -> &str; - fn token_revocation_endpoint(&self) -> Option<&str>; - fn supports_pkce(&self) -> bool; - - fn token_endpoint(&self) -> Option<&str>; - fn redirect_endpoint(&self) -> Option<&str>; - fn redirect_proxy_endpoint(&self) -> Option<&str>; - - fn provider_info(&self, client_type: &ClientType) -> OAuthProviderInfo { - OAuthProviderInfo { - provider: self.name(), - client_id: self.client_id(client_type).to_string(), - auth_url_endpoint: self.auth_url_endpoint().to_string(), - device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), - token_endpoint: self.token_endpoint().map(|s| s.to_string()), - redirect_endpoint: self.redirect_endpoint().map(|s| s.to_string()), - redirect_proxy_endpoint: self.redirect_proxy_endpoint().map(|s| s.to_string()), - scopes: self - .scopes() - .into_iter() - .map(|s| s.to_string()) - .collect::>(), - } + + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo>; + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo>; + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo>; + + fn default_scopes(&self) -> &[String]; + + fn supports_pkce(&self) -> bool { + false } + // TODO: How can user info be change to something statically checked instead of a runtime check + // fn auth_ur_endpoint(&self) -> Option<&str>; + // fn redirect_endpoint(&self) -> Option<&str>; + + // fn token_exchange_content_type(&self) -> &str; + // fn token_exchange_endpoint(&self) -> &str; + // fn token_revocation_endpoint(&self) -> Option<&str>; + + // fn supports_pkce_only(&self) -> bool { false } + + // fn device_code_endpoint(&self) -> Option<&str>; + // fn token_endpoint(&self) -> Option<&str>; + + // fn provider_info(&self, client_type: &ClientType) -> Option { + // let default_scopes = self + // .scopes() + // .into_iter() + // .map(|s| s.to_string()) + // .collect::>(); + // self.client_id(client_type).map(|client_id| OAuthProviderInfo { + // provider: self.name(), + // client_id: client_id.to_string(), + // code: self.authz_code_flow_info(), + // pkce: self.authz_code_pkce_flow_info(), + // device: self.device_code_flow_info(), + // // auth_url_endpoint: self.auth_url_endpoint().to_string(), + // // device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), + // // token_endpoint: self.token_endpoint().map(|s| s.to_string()), + // // redirect_endpoint: self.redirect_endpoint().map(|s| s.to_string()), + // // supports_pkce_only: self.supports_pkce_only(), + // // scopes: self + // // .scopes() + // // .into_iter() + // // .map(|s| s.to_string()) + // // .collect::>(), + // }) + // } + fn as_web_client(&self) -> Result { - let mut client = - BasicClient::new(ClientId::new(self.client_id(&ClientType::Web).to_string())) - .set_auth_uri(AuthUrl::new(self.auth_url_endpoint().to_string())?) - .set_token_uri(TokenUrl::new(self.token_exchange_endpoint().to_string())?) - .set_revocation_url_option( - self.token_revocation_endpoint() - .map(|s| RevocationUrl::new(s.to_string())) - .transpose()?, - ) - .set_redirect_uri(RedirectUrl::new( - self.redirect_endpoint() - .ok_or(OAuthProviderError::MissingRedirectUri)? - .to_string(), - )?); - - if let Some(secret) = self.client_secret(&ClientType::Web) { - client = client.set_client_secret(ClientSecret::new(secret.expose_secret().to_string())) + match self.authz_code_flow_info() { + Some(info) => { + let client = BasicClient::new(ClientId::new(info.client_id.clone())) + .set_auth_uri(AuthUrl::new(info.auth_url_endpoint.clone())?) + .set_token_uri(TokenUrl::new(info.token_endpoint.clone())?) + .set_revocation_url_option( + info.revocation_endpoint + .as_ref() + .map(|url| RevocationUrl::new(url.to_string())) + .transpose()?, + ) + .set_redirect_uri(RedirectUrl::new(info.redirect_endpoint.to_string())?) + .set_client_secret(ClientSecret::new( + info.client_secret.0.expose_secret().to_string(), + )); + + Ok(client) + } + None => Err(OAuthProviderError::WebClientNotSupported), } - - Ok(client) } } @@ -172,19 +194,68 @@ where } } -#[derive(Debug, Deserialize, Serialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, + code: Option, + pkce: Option, + device: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodeInfo { + client_id: String, + client_secret: OpenApiSecretString, auth_url_endpoint: String, - device_code_endpoint: Option, - token_endpoint: Option, - redirect_endpoint: Option, - redirect_proxy_endpoint: Option, - scopes: Vec, + redirect_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + revocation_endpoint: Option, +} + +impl OAuthProviderAuthorizationCodeInfo { + fn as_web_client(&self) -> Result { + let client = BasicClient::new(ClientId::new(self.client_id.clone())) + .set_auth_uri(AuthUrl::new(self.auth_url_endpoint.clone())?) + .set_token_uri(TokenUrl::new(self.token_endpoint.clone())?) + .set_revocation_url_option( + self.revocation_endpoint + .as_ref() + .map(|url| RevocationUrl::new(url.to_string())) + .transpose()?, + ) + .set_redirect_uri(RedirectUrl::new(self.redirect_endpoint.to_string())?) + .set_client_secret(ClientSecret::new( + self.client_secret.0.expose_secret().to_string(), + )); + + Ok(client) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodePkceInfo { + client_id: String, + auth_url_endpoint: String, + redirect_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + revocation_endpoint: Option, + proxy_port: u16, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderDeviceInfo { + client_id: String, + client_secret: OpenApiSecretString, + device_code_endpoint: String, + token_endpoint_content_type: String, + token_endpoint: String, + revocation_endpoint: Option, } -#[derive(Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash, Serialize, JsonSchema)] #[serde(rename_all = "kebab-case")] pub enum OAuthProviderName { #[serde(rename = "github")] diff --git a/v-api/src/endpoints/login/oauth/zendesk.rs b/v-api/src/endpoints/login/oauth/zendesk.rs index b2421ce2..cedc3182 100644 --- a/v-api/src/endpoints/login/oauth/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/zendesk.rs @@ -4,30 +4,28 @@ use hyper::body::Bytes; use reqwest::Request; -use secrecy::SecretString; use serde::Deserialize; use std::fmt; -use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; - -use super::{ - ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, - OAuthPublicCredentials, +use crate::{ + config::ResolvedOAuthConfig, + endpoints::login::{ + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, + }, }; +use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; + pub struct ZendeskOAuthProvider { - device_public: OAuthPublicCredentials, - device_private: Option, - web_public: OAuthPublicCredentials, - web_private: Option, - additional_scopes: Vec, - client: reqwest::Client, + authz_code_flow_info: Option, + authz_code_pkce_flow_info: Option, user_info_endpoint: String, - auth_url_endpoint: String, - token_exchange_endpoint: String, - token_endpoint: Option, - redirect_endpoint: Option, - redirect_proxy_endpoint: Option, + default_scopes: Vec, + client: reqwest::Client, } impl fmt::Debug for ZendeskOAuthProvider { @@ -38,47 +36,81 @@ impl fmt::Debug for ZendeskOAuthProvider { impl ZendeskOAuthProvider { pub fn new( + config: ResolvedOAuthConfig, public_url: String, subdomain: String, - device_client_id: String, - device_client_secret: SecretString, - web_client_id: String, - web_client_secret: SecretString, additional_scopes: Option>, - redirect_proxy_port: u16, ) -> Self { + // let base_url = format!("https://{}.zendesk.com", subdomain); + + // Self { + // device_public: OAuthPublicCredentials { + // client_id: device_client_id, + // }, + // device_private: Some(OAuthPrivateCredentials { + // client_secret: device_client_secret, + // }), + // web_public: OAuthPublicCredentials { + // client_id: web_client_id, + // }, + // web_private: Some(OAuthPrivateCredentials { + // client_secret: web_client_secret, + // }), + // web_pkce: web_pkce_client_id.map(|client_id| OAuthPublicCredentials { + // client_id, + // }), + // web_pkce_port: web_pkce_port, + // additional_scopes: additional_scopes.unwrap_or_default(), + // client: reqwest::ClientBuilder::new() + // .redirect(reqwest::redirect::Policy::none()) + // .build() + // .expect("Static client must build"), + // user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + // auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + // token_exchange_endpoint: format!("{}/oauth/tokens", base_url), + // token_endpoint: Some(format!( + // "{}/login/oauth/zendesk/device/exchange", + // public_url + // )), + // redirect_endpoint: Some(format!("{}/login/oauth/zendesk/code/callback", public_url,)), + // } + let base_url = format!("https://{}.zendesk.com", subdomain); + let mut default_scopes = vec!["users:read".to_string()]; + default_scopes.extend(additional_scopes.unwrap_or_default()); + + let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { + client_id: web.client_id, + client_secret: web.client_secret.into(), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/oauth/tokens", base_url), + revocation_endpoint: None, + }); + let authz_code_pkce_flow_info = + config + .proxy_web + .map(|proxy| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/oauth/tokens", base_url), + proxy_port: proxy.proxy_port, + revocation_endpoint: None, + }); + Self { - device_public: OAuthPublicCredentials { - client_id: device_client_id, - }, - device_private: Some(OAuthPrivateCredentials { - client_secret: device_client_secret, - }), - web_public: OAuthPublicCredentials { - client_id: web_client_id, - }, - web_private: Some(OAuthPrivateCredentials { - client_secret: web_client_secret, - }), - additional_scopes: additional_scopes.unwrap_or_default(), + authz_code_flow_info, + authz_code_pkce_flow_info, + user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + default_scopes, client: reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), - user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), - auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), - token_exchange_endpoint: format!("{}/oauth/tokens", base_url), - token_endpoint: Some(format!( - "{}/login/oauth/zendesk/device/exchange", - public_url - )), - redirect_endpoint: Some(format!("{}/login/oauth/zendesk/code/callback", public_url,)), - redirect_proxy_endpoint: Some(format!( - "http://localhost:{}/login/oauth/zendesk/code/callback", - redirect_proxy_port - )), } } @@ -129,74 +161,24 @@ impl OAuthProvider for ZendeskOAuthProvider { fn name(&self) -> OAuthProviderName { OAuthProviderName::Zendesk } - - fn scopes(&self) -> Vec<&str> { - let mut default = vec!["users:read"]; - default.extend(self.additional_scopes.iter().map(|s| s.as_str())); - default - } - fn initialize_headers(&self, _request: &mut Request) {} - fn client(&self) -> &reqwest::Client { &self.client } - - fn client_id(&self, client_type: &ClientType) -> &str { - match client_type { - ClientType::Device => &self.device_public.client_id, - ClientType::Web => &self.web_public.client_id, - } - } - - fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { - match client_type { - ClientType::Device => self - .device_private - .as_ref() - .map(|private| &private.client_secret), - ClientType::Web => self - .web_private - .as_ref() - .map(|private| &private.client_secret), - } - } - fn user_info_endpoints(&self) -> Vec<&str> { vec![&self.user_info_endpoint] } - - fn device_code_endpoint(&self) -> Option<&str> { - None + fn default_scopes(&self) -> &[String] { + &self.default_scopes } - fn auth_url_endpoint(&self) -> &str { - &self.auth_url_endpoint + fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { + self.authz_code_flow_info.as_ref() } - - fn token_exchange_content_type(&self) -> &str { - "application/x-www-form-urlencoded" + fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { + self.authz_code_pkce_flow_info.as_ref() } - - fn token_exchange_endpoint(&self) -> &str { - &self.token_exchange_endpoint - } - - fn token_revocation_endpoint(&self) -> Option<&str> { + fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { None } - - fn supports_pkce(&self) -> bool { - false - } - - fn token_endpoint(&self) -> Option<&str> { - self.token_endpoint.as_deref() - } - fn redirect_endpoint(&self) -> Option<&str> { - self.redirect_endpoint.as_deref() - } - fn redirect_proxy_endpoint(&self) -> Option<&str> { - self.redirect_proxy_endpoint.as_deref() - } } diff --git a/v-api/src/secrets.rs b/v-api/src/secrets.rs index a4f95604..c55ab6a4 100644 --- a/v-api/src/secrets.rs +++ b/v-api/src/secrets.rs @@ -10,7 +10,7 @@ use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Serialize, Serializer}; use std::borrow::Cow; -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OpenApiSecretString(pub SecretString); impl From for OpenApiSecretString { diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml index c3c630ec..dc1d1d75 100644 --- a/v-cli-sdk/Cargo.toml +++ b/v-cli-sdk/Cargo.toml @@ -13,7 +13,11 @@ hyper = { workspace = true, features = ["server", "http1"] } hyper-util = { workspace = true, features = ["tokio"] } oauth2 = { workspace = true } oauth2-reqwest = { workspace = true } +owo-colors = { workspace = true } progenitor-client = { workspace = true } reqwest = { workspace = true } +schemars = { workspace = true } serde = { workspace = true } +serde_json = { workspace = true } +tabwriter = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 6d09a937..8b77c2cd 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -8,11 +8,8 @@ use oauth2::TokenResponse; use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; use crate::{ - cmd::{ - auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, - config::CliConfig, - }, - CliContext, + cmd::auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, + VCliConfig, VCliContext, }; pub trait CliAdapterToken { @@ -25,12 +22,12 @@ impl CliConsumerLoginProvider for T where T: Into + Subcommand // Authenticates and generates an access token for interacting with the api #[derive(Parser, Debug, Clone)] #[clap(name = "login")] -pub struct Login

+pub struct Login where - P: CliConsumerLoginProvider, + SupportedProviders: CliConsumerLoginProvider, { #[command(subcommand)] - method: LoginMethod

, + method: LoginMethod, #[arg(short = 'm', default_value = "id")] mode: AuthenticationMode, } @@ -41,8 +38,8 @@ where { pub async fn run(&self, ctx: &mut T) -> Result<()> where - T: CliContext, - >::Error: StdError + Send + Sync + 'static, + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, { let access_token = self.method.run(ctx, &self.mode).await?; @@ -54,15 +51,15 @@ where } #[derive(Subcommand, Debug, Clone)] -pub enum LoginMethod

+pub enum LoginMethod where - P: Subcommand + Debug + Clone, + SupportedProviders: Subcommand + Debug + Clone, { #[command(name = "oauth")] /// Login via OAuth OAuth { #[command(subcommand)] - provider: P, + provider: SupportedProviders, }, /// Login via Magic Link #[command(name = "mlink")] @@ -96,20 +93,20 @@ pub enum AuthenticationMode { Remote, } -impl

LoginMethod

+impl LoginMethod where - P: CliConsumerLoginProvider, + SupportedProviders: CliConsumerLoginProvider, { pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result where - T: CliContext, - >::Error: StdError + Send + Sync + 'static, + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, { match self { Self::OAuth { provider } => { let adapter = ctx.oauth_adapter(); let provider = provider.clone().into(); - let provider = adapter.provider(&provider).await?; + let provider = adapter.provider(provider).await?; // We now need to inspect the provider to determine the correct flow to use. If // possible we use a limited input device flow, but not all providers support it. @@ -118,7 +115,7 @@ where if provider.device_code_endpoint().is_some() { self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) .await - } else if provider.code_redirect_proxy_endpoint().is_some() { + } else if provider.supports_pkce_only() { self.run_oauth_code_provider(provider, mode, ctx.oauth_adapter()) .await } else { @@ -170,13 +167,13 @@ where async fn run_oauth_code_provider( &self, - provider: V, + provider: T, mode: &AuthenticationMode, - adapter: T, + adapter: V, ) -> Result where - T: CliOAuthAdapter + Send + Sync + 'static, - V: CliOAuthProviderInfo, + T: CliOAuthProviderInfo, + V: CliOAuthAdapter + Send + Sync + 'static, { let oauth_client = oauth::CodeOAuth::new(provider)?; let adapter = Arc::new(adapter); diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs index 5de8042e..ad7d6a24 100644 --- a/v-cli-sdk/src/cmd/auth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -6,7 +6,7 @@ use anyhow::Result; use clap::{Parser, Subcommand}; use std::error::Error as StdError; -use crate::{cmd::auth::login::CliConsumerLoginProvider, CliContext}; +use crate::{cmd::auth::login::CliConsumerLoginProvider, VCliContext}; pub mod login; pub mod oauth; @@ -36,10 +36,10 @@ impl

Auth

where P: CliConsumerLoginProvider, { - pub async fn run(&self, ctx: &mut T) -> Result<()> + pub async fn run(&self, ctx: &mut T) -> Result<()> where - T: CliContext, - >::Error: StdError + Send + Sync + 'static, + T: VCliContext, + >::Error: StdError + Send + Sync + 'static, { match &self.auth { AuthCommands::Login(login) => login.run(ctx).await, diff --git a/v-cli-sdk/src/cmd/auth/oauth.rs b/v-cli-sdk/src/cmd/auth/oauth.rs index 8f5f6751..950e859d 100644 --- a/v-cli-sdk/src/cmd/auth/oauth.rs +++ b/v-cli-sdk/src/cmd/auth/oauth.rs @@ -12,25 +12,26 @@ use http::{Request, Response, StatusCode}; use http_body_util::{BodyExt, Full}; use hyper::body::{Bytes, Incoming}; use oauth2::basic::{BasicClient, BasicTokenType}; -use oauth2::StandardDeviceAuthorizationResponse; use oauth2::{ AuthType, AuthUrl, ClientId, CsrfToken, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, EndpointSet, RedirectUrl, Scope, StandardTokenResponse, TokenUrl, }; -use reqwest::Url; +use oauth2::{PkceCodeChallenge, StandardDeviceAuthorizationResponse}; use tokio::sync::oneshot; use crate::cmd::auth::login::CliAdapterToken; use super::proxy::run_proxy_server; +static PROXY_PORT: u16 = 8174; + pub trait CliOAuthAdapter { type Token: CliAdapterToken; type Error: StdError + Send + Sync + 'static; fn provider( &self, - provider: &super::login::LoginProvider, + provider: super::login::LoginProvider, ) -> Pin> + Send>>; fn exchange_authorization_code( &self, @@ -43,10 +44,11 @@ pub trait CliOAuthAdapter { } pub trait CliOAuthProviderInfo { + fn supports_pkce_only(&self) -> bool; fn device_code_endpoint(&self) -> Option<&str>; - fn code_redirect_proxy_endpoint(&self) -> Option<&str>; fn auth_url_endpoint(&self) -> &str; fn token_endpoint(&self) -> &str; + fn redirect_endpoint(&self) -> Option<&str>; fn client_id(&self) -> &str; fn scopes(&self) -> &[String]; } @@ -75,45 +77,34 @@ impl CodeOAuth { where T: CliOAuthProviderInfo, { - let redirect_url = provider - .code_redirect_proxy_endpoint() - .ok_or_else(|| anyhow::anyhow!("Provider does not support code redirect proxy flow"))?; - - let parsed_url = Url::parse(redirect_url)?; - - let port = parsed_url.port().ok_or_else(|| { - anyhow::anyhow!("Provider proxy url does not have a defined port to listen on") - })?; - - if parsed_url.scheme() != "http" { - anyhow::bail!("Provider proxy url scheme must be http"); - } - - if parsed_url - .host_str() - .map(|h| h != "localhost" && h != "127.0.0.1") - .unwrap_or(true) - { - anyhow::bail!("Provider proxy url host must be localhost"); - } - let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) .set_auth_type(AuthType::RequestBody) .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) - .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?); + .set_redirect_uri(RedirectUrl::new( + provider + .redirect_endpoint() + .expect("OAuth code flow provider must define a redirect url") + .to_string(), + )?); Ok(Self { client, scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), - port, + port: PROXY_PORT, }) } /// Build the authorization URL that the user should visit in a browser. /// Returns the full URL and the CSRF state token used for verification. - pub fn authorize_url(&self) -> (oauth2::url::Url, CsrfToken) { - let mut req = self.client.authorize_url(CsrfToken::new_random); + pub fn authorize_url( + &self, + pkce_challenge: PkceCodeChallenge, + ) -> (oauth2::url::Url, CsrfToken) { + let mut req = self + .client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); for scope in &self.scopes { req = req.add_scope(Scope::new(scope.to_string())); @@ -133,7 +124,8 @@ impl CodeOAuth { where T: CliOAuthAdapter + Send + Sync + 'static, { - let (auth_url, _csrf_state) = self.authorize_url(); + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, _csrf_state) = self.authorize_url(pkce_challenge); println!( "Open the following URL in your browser to authenticate:\n\n {}\n", diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs index bf2d123e..1ba94478 100644 --- a/v-cli-sdk/src/cmd/config/mod.rs +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -5,14 +5,14 @@ use anyhow::Result; use clap::{Parser, Subcommand}; -use crate::{CliContext, FormatStyle}; +use crate::{FormatStyle, VCliContext}; -pub trait CliConfig { +pub trait VCliConfig { fn host(&self) -> Option<&str>; fn set_host(&mut self, host: String); fn token(&self) -> Option<&str>; fn set_token(&mut self, token: String); - fn default_format(&self) -> Option<&FormatStyle>; + fn default_format(&self) -> FormatStyle; fn set_default_format(&mut self, format: FormatStyle); fn mlink_redirect(&self) -> Option<&str>; fn set_mlink_redirect(&mut self, redirect: String); @@ -76,7 +76,7 @@ pub enum SetCmd { impl ConfigCmd { pub async fn run(&self, ctx: &mut T) -> Result<()> where - T: CliContext, + T: VCliContext, { match &self.setting { SettingCmd::Get(get) => get.run(ctx.config()).await?, @@ -90,17 +90,11 @@ impl ConfigCmd { impl GetCmd { pub async fn run(&self, config: &T) -> Result<()> where - T: CliConfig, + T: VCliConfig, { match &self { GetCmd::Format => { - println!( - "{}", - config - .default_format() - .copied() - .unwrap_or(FormatStyle::Json) - ); + println!("{}", config.default_format()); } GetCmd::Host => { println!("{}", config.host().unwrap_or("None")); @@ -123,7 +117,7 @@ impl GetCmd { impl SetCmd { pub async fn run(&self, config: &mut T) -> Result<()> where - T: CliConfig, + T: VCliConfig, { match &self { SetCmd::Format { format } => { diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs index e84eeb36..6f8066a6 100644 --- a/v-cli-sdk/src/err.rs +++ b/v-cli-sdk/src/err.rs @@ -5,12 +5,12 @@ use anyhow::{anyhow, Error}; use progenitor_client::Error as ProgenitorClientError; -use crate::{ApiErrorMessage, CliContext, VerbosityLevel}; +use crate::{VApiErrorMessage, VCliContext, VerbosityLevel}; pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error where - T: CliContext, - E: ApiErrorMessage, + T: VCliContext, + E: VApiErrorMessage, { let mut err = anyhow!("API Request failed"); diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index 652ab27c..eb7950ab 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -6,22 +6,24 @@ use clap::ValueEnum; use serde::{Deserialize, Serialize}; use std::fmt::Display; -use crate::cmd::{ - auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}, - config::CliConfig, -}; - pub mod cmd; pub mod err; +pub mod printer; + +use crate::cmd::auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}; +pub use cmd::config::VCliConfig; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum VerbosityLevel { None, All, } -#[derive(Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize)] +#[derive( + Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize, Default, +)] pub enum FormatStyle { + #[default] #[value(name = "json")] Json, #[value(name = "tab")] @@ -37,14 +39,13 @@ impl Display for FormatStyle { } } -pub trait CliContext { - type Attempt; +pub trait VCliContext { type Token; type Error; - fn config(&self) -> &impl CliConfig; - fn config_mut(&mut self) -> &mut impl CliConfig; - fn client(&self) -> Option<&C>; + fn config(&self) -> &impl VCliConfig; + fn config_mut(&mut self) -> &mut impl VCliConfig; + fn client(&self) -> Option; fn printer(&self) -> Option<&P>; fn verbosity(&self) -> VerbosityLevel; @@ -56,7 +57,7 @@ pub trait CliContext { ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; } -pub trait ApiErrorMessage { +pub trait VApiErrorMessage { fn message(&self) -> Option<&str>; fn error_code(&self) -> Option<&str>; fn request_id(&self) -> Option<&str>; diff --git a/v-cli-sdk/src/printer/mod.rs b/v-cli-sdk/src/printer/mod.rs new file mode 100644 index 00000000..4685b8ed --- /dev/null +++ b/v-cli-sdk/src/printer/mod.rs @@ -0,0 +1,227 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use owo_colors::{OwoColorize, Style}; +use serde::Serialize; +use std::io::Write; +use tabwriter::TabWriter; + +#[derive(Debug, Clone)] +pub enum Printer { + Json, + Tab, +} + +pub trait CliOutput { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug; +} + +impl Printer { + /// Print any serializable response object in the configured format. + /// + /// - `Json` mode emits compact, single-line JSON. + /// - `Tab` mode serializes to a `serde_json::Value` and pretty-prints it + /// with tab-aligned key/value pairs. + pub fn print_response(&self, value: &T) + where + T: Serialize, + { + let json_value = serde_json::to_value(value) + .unwrap_or_else(|e| serde_json::Value::String(format!("", e))); + + match self { + Printer::Json => { + println!("{}", serde_json::to_string(&json_value).unwrap_or_default()); + } + Printer::Tab => { + let styles = TabStyles::default(); + let mut tw = TabWriter::new(vec![]).ansi(true); + pretty_print_value(&mut tw, &json_value, 0, &styles); + tw.flush().unwrap(); + let output = String::from_utf8(tw.into_inner().unwrap()).unwrap(); + print!("{}", output); + } + } + } + + /// Print an error from a progenitor client response. + /// + /// A 401 Unauthorized is treated specially: instead of dumping the raw + /// server error we print a short, actionable message telling the user to + /// authenticate first. + pub fn print_error_response(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + // Check for 401 Unauthorized up-front, regardless of output format. + if let Some(status) = value.status() { + if status == reqwest::StatusCode::UNAUTHORIZED { + eprintln!("Authentication required. Please run `sprue auth login` first."); + return; + } + } + + match self { + Printer::Json => { + // For JSON mode, try to extract a serializable body from the + // error and fall back to the Debug representation. + let msg = match value { + progenitor_client::Error::ErrorResponse(rv) => { + serde_json::to_string(rv.as_ref()).ok() + } + _ => None, + }; + eprintln!("{}", msg.unwrap_or_else(|| format!("{:?}", value))); + } + Printer::Tab => { + eprintln!("{}", value); + } + } + } +} + +impl CliOutput for Printer { + fn output_error(&self, value: &progenitor_client::Error) + where + T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, + { + self.print_error_response(value); + } +} + +// --------------------------------------------------------------------------- +// Tab-indented pretty-printer for serde_json::Value +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone)] +struct TabStyles { + label: Style, + value: Style, + null: Style, +} + +impl Default for TabStyles { + fn default() -> Self { + TabStyles { + label: Style::new().bold(), + value: Style::new(), + null: Style::new().dimmed(), + } + } +} + +fn indent(tw: &mut TabWriter>, depth: usize) { + for _ in 0..depth { + let _ = write!(tw, "\t"); + } +} + +fn pretty_print_value( + tw: &mut TabWriter>, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(map) => { + for (key, val) in map { + pretty_print_field(tw, key, val, depth, styles); + } + } + serde_json::Value::Array(arr) => { + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 1, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!(tw, "{}", format_scalar(value, styles)); + } + } +} + +fn pretty_print_field( + tw: &mut TabWriter>, + key: &str, + value: &serde_json::Value, + depth: usize, + styles: &TabStyles, +) { + match value { + serde_json::Value::Object(_) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + pretty_print_value(tw, value, depth + 1, styles); + } + serde_json::Value::Array(arr) if arr.is_empty() => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + "[]".style(styles.null), + ); + } + serde_json::Value::Array(arr) if arr.iter().all(is_scalar) => { + // Print simple arrays inline, one value per line with the key on + // the first line only (mimics the existing TabDisplay list style). + for (i, val) in arr.iter().enumerate() { + indent(tw, depth); + if i == 0 { + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(val, styles), + ); + } else { + let _ = writeln!(tw, "\t{}", format_scalar(val, styles)); + } + } + } + serde_json::Value::Array(arr) => { + indent(tw, depth); + let _ = writeln!(tw, "{}:", key.style(styles.label)); + for (i, val) in arr.iter().enumerate() { + indent(tw, depth + 1); + let _ = writeln!(tw, "{}", format!("[{}]", i).style(styles.label),); + pretty_print_value(tw, val, depth + 2, styles); + } + } + _ => { + indent(tw, depth); + let _ = writeln!( + tw, + "{}:\t{}", + key.style(styles.label), + format_scalar(value, styles), + ); + } + } +} + +fn is_scalar(value: &serde_json::Value) -> bool { + matches!( + value, + serde_json::Value::Null + | serde_json::Value::Bool(_) + | serde_json::Value::Number(_) + | serde_json::Value::String(_) + ) +} + +fn format_scalar(value: &serde_json::Value, styles: &TabStyles) -> String { + match value { + serde_json::Value::Null => format!("{}", "null".style(styles.null)), + serde_json::Value::Bool(b) => format!("{}", b.style(styles.value)), + serde_json::Value::Number(n) => format!("{}", n.style(styles.value)), + serde_json::Value::String(s) => format!("{}", s.style(styles.value)), + // Fallback for non-scalars that end up here + other => format!("{}", other.style(styles.value)), + } +} From a5f6aec314b088e2765aa47f5d2385c5f47b7645 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 14:13:43 -0500 Subject: [PATCH 07/51] Rough cut --- Cargo.lock | 1 + v-api-permission-derive/src/lib.rs | 2 + v-api/src/config.rs | 45 +-- v-api/src/context/auth.rs | 28 +- v-api/src/context/mod.rs | 48 ++- v-api/src/context/oauth.rs | 9 +- v-api/src/endpoints/handlers.rs | 34 +- v-api/src/endpoints/login/magic_link/mod.rs | 1 + v-api/src/endpoints/login/mod.rs | 1 + v-api/src/endpoints/login/oauth/client.rs | 5 +- .../endpoints/login/oauth/{ => flow}/code.rs | 44 +-- .../login/oauth/{ => flow}/device_token.rs | 7 +- v-api/src/endpoints/login/oauth/flow/mod.rs | 2 + v-api/src/endpoints/login/oauth/mod.rs | 114 ++----- .../login/oauth/{ => remote}/github.rs | 31 +- .../login/oauth/{ => remote}/google.rs | 49 +-- v-api/src/endpoints/login/oauth/remote/mod.rs | 3 + .../login/oauth/{ => remote}/zendesk.rs | 72 ++--- v-cli-sdk/Cargo.toml | 1 + v-cli-sdk/src/cmd/auth/login.rs | 91 ++++-- v-cli-sdk/src/cmd/auth/oauth.rs | 298 ------------------ v-cli-sdk/src/cmd/auth/oauth/code.rs | 221 +++++++++++++ v-cli-sdk/src/cmd/auth/oauth/device.rs | 92 ++++++ v-cli-sdk/src/cmd/auth/oauth/mod.rs | 50 +++ v-cli-sdk/src/cmd/auth/proxy.rs | 46 ++- v-cli-sdk/src/lib.rs | 7 +- 26 files changed, 694 insertions(+), 608 deletions(-) rename v-api/src/endpoints/login/oauth/{ => flow}/code.rs (98%) rename v-api/src/endpoints/login/oauth/{ => flow}/device_token.rs (97%) create mode 100644 v-api/src/endpoints/login/oauth/flow/mod.rs rename v-api/src/endpoints/login/oauth/{ => remote}/github.rs (80%) rename v-api/src/endpoints/login/oauth/{ => remote}/google.rs (74%) create mode 100644 v-api/src/endpoints/login/oauth/remote/mod.rs rename v-api/src/endpoints/login/oauth/{ => remote}/zendesk.rs (60%) delete mode 100644 v-cli-sdk/src/cmd/auth/oauth.rs create mode 100644 v-cli-sdk/src/cmd/auth/oauth/code.rs create mode 100644 v-cli-sdk/src/cmd/auth/oauth/device.rs create mode 100644 v-cli-sdk/src/cmd/auth/oauth/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 2a07fcd3..9dc0cfb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3916,6 +3916,7 @@ dependencies = [ "serde_json", "tabwriter", "tokio", + "uuid", ] [[package]] diff --git a/v-api-permission-derive/src/lib.rs b/v-api-permission-derive/src/lib.rs index 62dfe6ca..0dc29db4 100644 --- a/v-api-permission-derive/src/lib.rs +++ b/v-api-permission-derive/src/lib.rs @@ -439,6 +439,7 @@ fn from_system_permission_tokens( VPermission::ManageMagicLinkClientsAll => Self::ManageMagicLinkClientsAll, VPermission::CreateAccessToken => Self::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken => Self::RetrieveRemoteAccessToken, VPermission::GetSagasAll => Self::GetSagasAll, VPermission::ManageSagasAll => Self::ManageSagasAll, @@ -722,6 +723,7 @@ fn system_permission_tokens() -> TokenStream { ManageMagicLinkClientsAll, CreateAccessToken, + RetrieveRemoteAccessToken, #[v_api(scope(to = "saga:r", from = "saga:r"))] GetSagasAll, diff --git a/v-api/src/config.rs b/v-api/src/config.rs index e56a07fc..18e0e32d 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -8,6 +8,7 @@ use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, Jwk, KeyAlgorithm, PublicKeyUse, RSAKeyParameters, RSAKeyType, }; +use newtype_uuid::TypedUuid; use partial_struct::partial; use rsa::{ pkcs1v15::{SigningKey, VerifyingKey}, @@ -20,6 +21,7 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; +use v_model::OAuthClientId; use std::path::PathBuf; use thiserror::Error; use v_api_param::{ParamResolutionError, StringParam}; @@ -156,7 +158,7 @@ pub struct OAuthProviders { } #[partial(ResolvedOAuthConfig)] -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthConfig { #[partial(ResolvedOAuthConfig(retype = Option))] pub device: Option, @@ -167,26 +169,27 @@ pub struct OAuthConfig { } #[partial(ResolvedOAuthDeviceConfig)] -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthDeviceConfig { - pub client_id: String, + pub client_id: TypedUuid, + pub remote_client_id: String, #[partial(ResolvedOAuthDeviceConfig(retype = SecretString))] - pub client_secret: StringParam, + pub remote_client_secret: StringParam, } #[partial(ResolvedOAuthWebConfig)] -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthWebConfig { - pub client_id: String, + pub remote_client_id: String, #[partial(ResolvedOAuthWebConfig(retype = SecretString))] - pub client_secret: StringParam, - pub redirect_uri: String, + pub remote_client_secret: StringParam, } #[partial(ResolvedOAuthWebProxyConfig)] -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct OAuthWebProxyConfig { - pub client_id: String, + pub client_id: TypedUuid, + pub redirect_uri: String, pub proxy_port: u16, } @@ -195,8 +198,11 @@ impl OAuthConfig { &self, base: Option, ) -> Result { - let device = self.device.as_ref().and_then(|d| d.resolve(base).ok()); - let web = self.web.as_ref().and_then(|w| w.resolve(base).ok()); + let device = self + .device + .as_ref() + .and_then(|d| d.resolve(base.clone()).ok()); + let web = self.web.as_ref().and_then(|w| w.resolve(base.clone()).ok()); let proxy_web = self.proxy_web.as_ref().and_then(|p| p.resolve(base).ok()); Ok(ResolvedOAuthConfig { device, @@ -210,10 +216,11 @@ impl OAuthDeviceConfig { &self, base: Option, ) -> Result { - let client_secret = self.client_secret.resolve(base)?; + let remote_client_secret = self.remote_client_secret.resolve(base)?; Ok(ResolvedOAuthDeviceConfig { client_id: self.client_id.clone(), - client_secret, + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, }) } } @@ -222,21 +229,21 @@ impl OAuthWebConfig { &self, base: Option, ) -> Result { - let client_secret = self.client_secret.resolve(base)?; + let remote_client_secret = self.remote_client_secret.resolve(base)?; Ok(ResolvedOAuthWebConfig { - client_id: self.client_id.clone(), - client_secret, - redirect_uri: self.redirect_uri.clone(), + remote_client_id: self.remote_client_id.clone(), + remote_client_secret, }) } } impl OAuthWebProxyConfig { pub fn resolve( &self, - base: Option, + _base: Option, ) -> Result { Ok(ResolvedOAuthWebProxyConfig { client_id: self.client_id.clone(), + redirect_uri: self.redirect_uri.clone(), proxy_port: self.proxy_port, }) } diff --git a/v-api/src/context/auth.rs b/v-api/src/context/auth.rs index a818820b..dd8a5719 100644 --- a/v-api/src/context/auth.rs +++ b/v-api/src/context/auth.rs @@ -40,9 +40,22 @@ where jwks: JwkSet, signers: Vec, verifiers: Vec, + mut additional_permissions: Vec, ) -> Result { let signers = signers.into_iter().map(Arc::new).collect::>(); let verifiers = verifiers.into_iter().map(Arc::new).collect::>(); + additional_permissions.extend_from_slice(&[ + VPermission::CreateApiUser.into(), + VPermission::GetApiUsersAll.into(), + VPermission::ManageApiUsersAll.into(), + VPermission::GetApiKeysAll.into(), + VPermission::CreateGroup.into(), + VPermission::GetGroupsAll.into(), + VPermission::CreateMapper.into(), + VPermission::GetMappersAll.into(), + VPermission::GetOAuthClientsAll.into(), + VPermission::CreateAccessToken.into(), + ]); Ok(Self { unauthenticated_caller: Caller { id: "00000000-0000-4000-8000-000000000000".parse().unwrap(), @@ -51,19 +64,7 @@ where }, registration_caller: Caller { id: "00000000-0000-4000-8000-000000000001".parse().unwrap(), - permissions: vec![ - VPermission::CreateApiUser, - VPermission::GetApiUsersAll, - VPermission::ManageApiUsersAll, - VPermission::GetApiKeysAll, - VPermission::CreateGroup, - VPermission::GetGroupsAll, - VPermission::CreateMapper, - VPermission::GetMappersAll, - VPermission::GetOAuthClientsAll, - VPermission::CreateAccessToken, - ] - .into(), + permissions: additional_permissions.into(), extensions: HashMap::default(), }, jwt: JwtContext { @@ -215,6 +216,7 @@ mod tests { wrong_verifier.resolve_verifier(None).await.unwrap(), verifier.resolve_verifier(None).await.unwrap(), ], + vec![], ) .unwrap(); diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index 5672eab4..afe51505 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -764,6 +764,7 @@ pub struct VContextBuilder { keys: Option>, #[cfg(feature = "sagas")] saga: Option<(TypedUuid, Option)>, + additional_builtin_permissions: Vec, } impl Default for VContextBuilder @@ -790,6 +791,7 @@ where keys: None, #[cfg(feature = "sagas")] saga: None, + additional_builtin_permissions: Vec::new(), } } @@ -838,6 +840,11 @@ where self } + pub fn with_additional_builtin_permissions(mut self, permissions: Vec) -> Self { + self.additional_builtin_permissions = permissions; + self + } + pub async fn build(self) -> Result, VContextBuilderError> { if self.storage.is_some() && self.storage_url.is_some() { return Err(VContextBuilderError::ConfigConflict( @@ -900,7 +907,14 @@ where .into_iter() .filter_map(|key| key.ok()) .collect::>(); - let auth_ctx = AuthContext::new(jwt, jwks, signers, verifiers).map_err(|err| { + let auth_ctx = AuthContext::new( + jwt, + jwks, + signers, + verifiers, + self.additional_builtin_permissions, + ) + .map_err(|err| { tracing::error!(?err, "Auth context construction failed"); VContextError::InternalAuthContext })?; @@ -1245,9 +1259,12 @@ pub(crate) mod test_mocks { }; use crate::{ - config::JwtConfig, + config::{ + JwtConfig, ResolvedOAuthConfig, ResolvedOAuthWebConfig, ResolvedOAuthWebProxyConfig, + }, endpoints::login::oauth::{ - google::GoogleOAuthProvider, zendesk::ZendeskOAuthProvider, OAuthProviderName, + remote::google::GoogleOAuthProvider, remote::zendesk::ZendeskOAuthProvider, + OAuthProviderName, }, mapper::DefaultMappingEngine, permissions::VPermission, @@ -1280,11 +1297,15 @@ pub(crate) mod test_mocks { OAuthProviderName::Google, Box::new(move || { Box::new(GoogleOAuthProvider::new( + ResolvedOAuthConfig { + device: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "google_web_client_id".to_string(), + remote_client_secret: "google_web_client_secret".to_string().into(), + }), + proxy_web: None, + }, "https://test_public_url".to_string(), - "google_device_client_id".to_string(), - "google_device_client_secret".to_string().into(), - "google_web_client_id".to_string(), - "google_web_client_secret".to_string().into(), None, )) }), @@ -1294,12 +1315,17 @@ pub(crate) mod test_mocks { OAuthProviderName::Zendesk, Box::new(move || { Box::new(ZendeskOAuthProvider::new( + ResolvedOAuthConfig { + device: None, + web: None, + proxy_web: Some(ResolvedOAuthWebProxyConfig { + client_id: TypedUuid::new_v4(), + redirect_uri: "test".to_string(), + proxy_port: 1234, + }), + }, "https://test_public_url".to_string(), "subdomain".to_string(), - "google_device_client_id".to_string(), - "google_device_client_secret".to_string().into(), - "google_web_client_id".to_string(), - "google_web_client_secret".to_string().into(), None, )) }), diff --git a/v-api/src/context/oauth.rs b/v-api/src/context/oauth.rs index 41835b60..1a9ee0c4 100644 --- a/v-api/src/context/oauth.rs +++ b/v-api/src/context/oauth.rs @@ -40,15 +40,10 @@ where pub async fn create_oauth_client( &self, caller: &Caller, + id: TypedUuid, ) -> ResourceResult { if caller.can(&VPermission::CreateOAuthClient.into()) { - Ok(OAuthClientStore::upsert( - &*self.storage, - NewOAuthClient { - id: TypedUuid::new_v4(), - }, - ) - .await?) + Ok(OAuthClientStore::upsert(&*self.storage, NewOAuthClient { id }).await?) } else { resource_restricted() } diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index f5de62bb..f939ee38 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -69,16 +69,16 @@ mod macros { DeleteOAuthClientSecretPath, GetOAuthClientPath, InitialOAuthClientSecretResponse, }, - code::{ + flow::code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, - get_web_pkce_provider_op, - OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, - OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, OAuthAuthzCodeExchangeQuery + get_public_pkce_provider_op, + OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, OAuthAuthzCodeExchangeQuery, + OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, }, - device_token::{ + flow::device_token::{ exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest, }, - OAuthProviderInfo, OAuthProviderNameParam, + OAuthProviderNameParam, OAuthProviderDeviceInfo, OAuthProviderAuthorizationCodePkceInfo } }, mappers::{ @@ -260,7 +260,7 @@ mod macros { // LOGIN ENDPOINTS - // AUTHZ CODE + // AUTHORIZATION CODE FLOW /// Generate the remote provider login url and redirect the user #[endpoint { @@ -296,27 +296,29 @@ mod macros { }] pub async fn authz_code_exchange( rqctx: RequestContext<$context_type>, - path: Path, query: Query, + path: Path, body: TypedBody, ) -> Result, HttpError> { - authz_code_exchange_op(&rqctx, path, query, body).await + authz_code_exchange_op(&rqctx, query, path, body).await } - // DEVICE CODE + // AUTHORIZATION CODE PKCE ONLY FLOW - /// Retrieve the metadata about an OAuth provider for public authorization code flow + /// Retrieve the metadata about an OAuth provider for public PKCE authorization code flow #[endpoint { method = GET, - path = "/login/oauth/{provider}/web-pkce" + path = "/login/oauth/{provider}/public-pkce" }] pub async fn get_web_pkce_provider( rqctx: RequestContext<$context_type>, path: Path, - ) -> Result, HttpError> { - get_web_pkce_provider_op(&rqctx, path).await + ) -> Result, HttpError> { + get_public_pkce_provider_op(&rqctx, path).await } + // DEVICE CODE FLOW + /// Retrieve the metadata about an OAuth provider for limited input flow #[endpoint { method = GET, @@ -325,7 +327,7 @@ mod macros { pub async fn get_device_provider( rqctx: RequestContext<$context_type>, path: Path, - ) -> Result, HttpError> { + ) -> Result, HttpError> { get_device_provider_op(&rqctx, path).await } @@ -771,7 +773,7 @@ mod macros { .expect("Failed to register endpoint"); // OAuth Login - $api.register(get_web_provider) + $api.register(get_web_pkce_provider) .expect("Failed to register endpoint"); $api.register(get_device_provider) .expect("Failed to register endpoint"); diff --git a/v-api/src/endpoints/login/magic_link/mod.rs b/v-api/src/endpoints/login/magic_link/mod.rs index 2108a5b6..b666b7cc 100644 --- a/v-api/src/endpoints/login/magic_link/mod.rs +++ b/v-api/src/endpoints/login/magic_link/mod.rs @@ -218,6 +218,7 @@ where external_id: ExternalUserId::MagicLink(body.recipient.clone()), verified_emails: vec![body.recipient], display_name: None, + idp_token: None, }, ) .await?; diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index 05a7207e..ef683c2e 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -191,6 +191,7 @@ pub struct UserInfo { pub external_id: ExternalUserId, pub verified_emails: Vec, pub display_name: Option, + pub idp_token: Option, } #[derive(Debug, Error)] diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index c8fd8304..b7d799d7 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -54,7 +54,10 @@ where T: VAppPermission + From + PermissionStorage, { // Create the new client - let client = ctx.oauth.create_oauth_client(&caller).await?; + let client = ctx + .oauth + .create_oauth_client(&caller, TypedUuid::new_v4()) + .await?; // Give the caller permission to perform actions on the client ctx.user diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs similarity index 98% rename from v-api/src/endpoints/login/oauth/code.rs rename to v-api/src/endpoints/login/oauth/flow/code.rs index 1528f495..c6529c26 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -31,7 +31,8 @@ use v_model::{ LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, }; -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider}; +use super::super::{OAuthProvider, OAuthProviderNameParam}; +use crate::endpoints::login::UserInfoProvider; use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, @@ -155,7 +156,7 @@ where } #[instrument(skip(rqctx), err(Debug))] -pub async fn get_web_pkce_provider_op( +pub async fn get_public_pkce_provider_op( rqctx: &RequestContext>, path: Path, ) -> Result, HttpError> @@ -466,8 +467,7 @@ where #[derive(Debug, Deserialize, JsonSchema)] pub struct OAuthAuthzCodeExchangeQuery { - #[serde(default)] - pub include_idp_token: bool, + pub request_idp_token: bool, } #[derive(Debug, Deserialize, JsonSchema)] @@ -485,7 +485,7 @@ pub struct OAuthAuthzCodeExchangeResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, - pub idp_token: Option, + pub idp_token: Option, } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -496,16 +496,16 @@ pub struct OAuthAuthzCodeIdpToken { #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_exchange_op( rqctx: &RequestContext>, - path: Path, query: Query, + path: Path, body: TypedBody, ) -> Result, HttpError> where T: VAppPermission + PermissionStorage, { let ctx = rqctx.v_ctx(); - let path = path.into_inner(); let query = query.into_inner(); + let path = path.into_inner(); let body = body.into_inner(); let provider = ctx @@ -527,13 +527,17 @@ where ), auth.password().unwrap().to_string(), ))), + Ok(auth) if auth.username().is_none() && auth.password().is_none() => { + tracing::info!("Credentials for code exchange not defined via basic auth"); + Ok(None) + } Ok(_) => Err(bad_request( "Malformed credentials presented to code exchange", )), Err(err) => { tracing::info!( ?err, - "Credentials for code exchange not defined via basic auth" + "Failed to extract basic authentication credentials" ); Ok(None) } @@ -546,6 +550,7 @@ where // We of course deny underspecifying credentials, but we also want to disallow over specifying // them. For example, if the client provides both basic auth and a client id/secret in the // request body, we should reject the request. + tracing::debug!(?basic_credentials, ?body_credentials, "Extracted credentials from request"); let (client_id, client_secret) = match (basic_credentials, body_credentials) { (Some(_), (Some(_), _)) => Err(bad_request( "Cannot provide both basic auth and client credentials", @@ -604,14 +609,15 @@ where // Now that the attempt has been confirmed, use it to fetch user information form the remote // provider - let (info, raw_token) = fetch_user_info( + let info = fetch_user_info( ctx.public_url(), &ctx.web_client(), &*provider, &attempt, - query.include_idp_token, + !query.request_idp_token, ) .await?; + let idp_token = info.idp_token.clone(); tracing::debug!("Retrieved user information from remote provider"); @@ -659,9 +665,7 @@ where token_type: "Bearer".to_string(), access_token: token.signed_token, expires_in: token.expires_in, - idp_token: query.include_idp_token.then(|| OAuthAuthzCodeIdpToken { - token: raw_token.unwrap(), - }), + idp_token, })) } @@ -801,8 +805,8 @@ async fn fetch_user_info( client_type: &ClientType, provider: &dyn OAuthProvider, attempt: &LoginAttempt, - return_raw: bool, -) -> Result<(UserInfo, Option), HttpError> { + revoke_idp_token: bool, +) -> Result { let provider_info = provider .authz_code_flow_info() .ok_or_else(|| internal_error("Authorization code flow not supported"))?; @@ -842,7 +846,7 @@ async fn fetch_user_info( // Now that we are done with fetching user information from the remote API, we can revoke it if // the provider supports it - if !return_raw && provider_info.revocation_endpoint.is_some() { + if revoke_idp_token && provider_info.remote.revocation_endpoint.is_some() { client .revoke_token(response.access_token().into()) .map_err(internal_error)? @@ -851,11 +855,7 @@ async fn fetch_user_info( .map_err(internal_error)?; } - if return_raw { - Ok((info, Some(response.access_token().secret().to_string()))) - } else { - Ok((info, None)) - } + Ok(info) } #[cfg(test)] @@ -891,7 +891,7 @@ mod tests { VContext, }, endpoints::login::oauth::{ - code::{ + flow::code::{ authz_code_callback_op_inner, verify_csrf, verify_login_attempt, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, LOGIN_ATTEMPT_COOKIE, }, diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs similarity index 97% rename from v-api/src/endpoints/login/oauth/device_token.rs rename to v-api/src/endpoints/login/oauth/flow/device_token.rs index 77530ea8..d18ddae4 100644 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -14,7 +14,8 @@ use tap::TapFallible; use tracing::instrument; use v_model::permissions::PermissionStorage; -use super::{OAuthProviderNameParam, UserInfoProvider}; +use super::super::OAuthProviderNameParam; +use crate::endpoints::login::UserInfoProvider; use crate::{ context::ApiContext, endpoints::login::{oauth::OAuthProviderDeviceInfo, LoginError}, @@ -75,10 +76,10 @@ impl AccessTokenExchange { pub fn new(req: AccessTokenExchangeRequest, provider: &OAuthProviderDeviceInfo) -> Self { Self { provider: ProviderTokenExchange { - client_id: provider.client_id.clone(), + client_id: provider.remote_client_id.clone(), device_code: req.device_code, grant_type: req.grant_type, - client_secret: provider.client_secret.0.expose_secret().to_string(), + client_secret: provider.remote_client_secret.0.expose_secret().to_string(), }, expires_at: req.expires_at, } diff --git a/v-api/src/endpoints/login/oauth/flow/mod.rs b/v-api/src/endpoints/login/oauth/flow/mod.rs new file mode 100644 index 00000000..304abf26 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/flow/mod.rs @@ -0,0 +1,2 @@ +pub mod code; +pub mod device_token; diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index cbaeb1c6..3b6c58e6 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -5,18 +5,19 @@ use async_trait::async_trait; use http::Method; use hyper::{body::Bytes, header::HeaderValue, header::AUTHORIZATION}; +use newtype_uuid::TypedUuid; use oauth2::{ basic::BasicClient, url::ParseError, AuthUrl, ClientId, ClientSecret, EndpointMaybeSet, EndpointNotSet, EndpointSet, RedirectUrl, RevocationUrl, TokenUrl, }; use reqwest::Request; use schemars::JsonSchema; -use secrecy::{ExposeSecret, SecretString}; +use secrecy::ExposeSecret; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display}; use thiserror::Error; use tracing::instrument; -use v_model::OAuthClient; +use v_model::{OAuthClient, OAuthClientId}; use crate::{ authn::{key::RawKey, Verify}, @@ -26,11 +27,8 @@ use crate::{ use super::{UserInfo, UserInfoError, UserInfoProvider}; pub mod client; -pub mod code; -pub mod device_token; -pub mod github; -pub mod google; -pub mod zendesk; +pub mod flow; +pub mod remote; #[derive(Debug, Error)] pub enum OAuthProviderError { @@ -64,14 +62,6 @@ pub type WebClient = BasicClient< EndpointSet, >; -pub struct OAuthPublicCredentials { - client_id: String, -} - -pub struct OAuthPrivateCredentials { - client_secret: SecretString, -} - pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn name(&self) -> OAuthProviderName; fn initialize_headers(&self, request: &mut Request); @@ -85,62 +75,24 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn default_scopes(&self) -> &[String]; fn supports_pkce(&self) -> bool { - false + true } - // TODO: How can user info be change to something statically checked instead of a runtime check - // fn auth_ur_endpoint(&self) -> Option<&str>; - // fn redirect_endpoint(&self) -> Option<&str>; - - // fn token_exchange_content_type(&self) -> &str; - // fn token_exchange_endpoint(&self) -> &str; - // fn token_revocation_endpoint(&self) -> Option<&str>; - - // fn supports_pkce_only(&self) -> bool { false } - - // fn device_code_endpoint(&self) -> Option<&str>; - // fn token_endpoint(&self) -> Option<&str>; - - // fn provider_info(&self, client_type: &ClientType) -> Option { - // let default_scopes = self - // .scopes() - // .into_iter() - // .map(|s| s.to_string()) - // .collect::>(); - // self.client_id(client_type).map(|client_id| OAuthProviderInfo { - // provider: self.name(), - // client_id: client_id.to_string(), - // code: self.authz_code_flow_info(), - // pkce: self.authz_code_pkce_flow_info(), - // device: self.device_code_flow_info(), - // // auth_url_endpoint: self.auth_url_endpoint().to_string(), - // // device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), - // // token_endpoint: self.token_endpoint().map(|s| s.to_string()), - // // redirect_endpoint: self.redirect_endpoint().map(|s| s.to_string()), - // // supports_pkce_only: self.supports_pkce_only(), - // // scopes: self - // // .scopes() - // // .into_iter() - // // .map(|s| s.to_string()) - // // .collect::>(), - // }) - // } - fn as_web_client(&self) -> Result { match self.authz_code_flow_info() { Some(info) => { - let client = BasicClient::new(ClientId::new(info.client_id.clone())) - .set_auth_uri(AuthUrl::new(info.auth_url_endpoint.clone())?) - .set_token_uri(TokenUrl::new(info.token_endpoint.clone())?) + let client = BasicClient::new(ClientId::new(info.remote.client_id.clone())) + .set_auth_uri(AuthUrl::new(info.remote.auth_url_endpoint.clone())?) + .set_token_uri(TokenUrl::new(info.remote.token_endpoint.clone())?) .set_revocation_url_option( - info.revocation_endpoint + info.remote.revocation_endpoint .as_ref() .map(|url| RevocationUrl::new(url.to_string())) .transpose()?, ) .set_redirect_uri(RedirectUrl::new(info.redirect_endpoint.to_string())?) .set_client_secret(ClientSecret::new( - info.client_secret.0.expose_secret().to_string(), + info.remote.client_secret.0.expose_secret().to_string(), )); Ok(client) @@ -190,7 +142,9 @@ where responses.push(bytes); } - self.extract_user_info(&responses) + let mut info = self.extract_user_info(&responses)?; + info.idp_token = Some(token.to_string()); + Ok(info) } } @@ -205,50 +159,36 @@ pub struct OAuthProviderInfo { #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OAuthProviderAuthorizationCodeInfo { - client_id: String, - client_secret: OpenApiSecretString, auth_url_endpoint: String, redirect_endpoint: String, token_endpoint_content_type: String, token_endpoint: String, - revocation_endpoint: Option, -} - -impl OAuthProviderAuthorizationCodeInfo { - fn as_web_client(&self) -> Result { - let client = BasicClient::new(ClientId::new(self.client_id.clone())) - .set_auth_uri(AuthUrl::new(self.auth_url_endpoint.clone())?) - .set_token_uri(TokenUrl::new(self.token_endpoint.clone())?) - .set_revocation_url_option( - self.revocation_endpoint - .as_ref() - .map(|url| RevocationUrl::new(url.to_string())) - .transpose()?, - ) - .set_redirect_uri(RedirectUrl::new(self.redirect_endpoint.to_string())?) - .set_client_secret(ClientSecret::new( - self.client_secret.0.expose_secret().to_string(), - )); - - Ok(client) - } + remote: OAuthProviderAuthorizationCodeRemoteInfo, } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct OAuthProviderAuthorizationCodePkceInfo { +pub struct OAuthProviderAuthorizationCodeRemoteInfo { client_id: String, + client_secret: OpenApiSecretString, auth_url_endpoint: String, - redirect_endpoint: String, token_endpoint_content_type: String, token_endpoint: String, revocation_endpoint: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OAuthProviderAuthorizationCodePkceInfo { + client_id: TypedUuid, + redirect_endpoint: String, proxy_port: u16, + web: OAuthProviderAuthorizationCodeInfo, } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OAuthProviderDeviceInfo { - client_id: String, - client_secret: OpenApiSecretString, + client_id: TypedUuid, + remote_client_id: String, + remote_client_secret: OpenApiSecretString, device_code_endpoint: String, token_endpoint_content_type: String, token_endpoint: String, diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs similarity index 80% rename from v-api/src/endpoints/login/oauth/github.rs rename to v-api/src/endpoints/login/oauth/remote/github.rs index 9b7de3e9..0d15a685 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -11,15 +11,13 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, - OAuthProviderDeviceInfo, - }, - ExternalUserId, UserInfo, UserInfoError, + ExternalUserId, UserInfo, UserInfoError, oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo + } }, }; -use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; pub struct GitHubOAuthProvider { authz_code_flow_info: Option, @@ -48,17 +46,23 @@ impl GitHubOAuthProvider { default_scopes.extend(additional_scopes.unwrap_or_default()); let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { - client_id: web.client_id, - client_secret: web.client_secret.into(), - auth_url_endpoint: "https://github.com/login/oauth/authorize".to_string(), - redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), + auth_url_endpoint: format!("{}/login/oauth/github/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: "https://github.com/login/oauth/access_token".to_string(), - revocation_endpoint: None, + token_endpoint: format!("{}/login/oauth/github/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://github.com/login/oauth/authorize".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://github.com/login/oauth/access_token".to_string(), + revocation_endpoint: None, + }, }); let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { client_id: device.client_id, - client_secret: device.client_secret.into(), + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), device_code_endpoint: "https://github.com/login/device/code".to_string(), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), token_endpoint: "https://github.com/login/oauth/access_token".to_string(), @@ -114,6 +118,7 @@ impl ExtractUserInfo for GitHubOAuthProvider { external_id: ExternalUserId::GitHub(user.id.to_string()), verified_emails, display_name: Some(user.login), + idp_token: None, }) } } diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs similarity index 74% rename from v-api/src/endpoints/login/oauth/google.rs rename to v-api/src/endpoints/login/oauth/remote/google.rs index 666815e4..24f02b91 100644 --- a/v-api/src/endpoints/login/oauth/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -10,15 +10,13 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, - OAuthProviderDeviceInfo, - }, - ExternalUserId, UserInfo, UserInfoError, + ExternalUserId, UserInfo, UserInfoError, oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo + } }, }; -use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; pub struct GoogleOAuthProvider { authz_code_flow_info: Option, @@ -48,33 +46,37 @@ impl GoogleOAuthProvider { default_scopes.extend(additional_scopes.unwrap_or_default()); let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { - client_id: web.client_id, - client_secret: web.client_secret.into(), - auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), - redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), + auth_url_endpoint: format!("{}/login/oauth/google/code/authorize", public_url), + redirect_endpoint: format!("{}/login/oauth/google/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: "https://oauth2.googleapis.com/token".to_string(), - revocation_endpoint: None, + token_endpoint: format!("{}/login/oauth/google/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), + }, }); let authz_code_pkce_flow_info = config .proxy_web - .map(|proxy| OAuthProviderAuthorizationCodePkceInfo { + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { client_id: proxy.client_id, - auth_url_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(), - redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url,), - token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + redirect_endpoint: proxy.redirect_uri, proxy_port: proxy.proxy_port, - revocation_endpoint: None, + web: web.clone() }); let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { client_id: device.client_id, - client_secret: device.client_secret.into(), - device_code_endpoint: "https://github.com/login/device/code".to_string(), + remote_client_id: device.remote_client_id, + remote_client_secret: device.remote_client_secret.into(), + device_code_endpoint: "https://oauth2.googleapis.com/device/code".to_string(), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: "https://github.com/login/oauth/access_token".to_string(), - revocation_endpoint: None, + token_endpoint: "https://oauth2.googleapis.com/token".to_string(), + revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), }); Self { @@ -144,6 +146,7 @@ impl ExtractUserInfo for GoogleOAuthProvider { external_id: ExternalUserId::Google(remote_info.sub), verified_emails, display_name, + idp_token: None, }) } } @@ -170,7 +173,7 @@ impl OAuthProvider for GoogleOAuthProvider { self.authz_code_flow_info.as_ref() } fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo> { - None + self.authz_code_pkce_flow_info.as_ref() } fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo> { self.device_code_flow_info.as_ref() diff --git a/v-api/src/endpoints/login/oauth/remote/mod.rs b/v-api/src/endpoints/login/oauth/remote/mod.rs new file mode 100644 index 00000000..db2a1dd9 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/remote/mod.rs @@ -0,0 +1,3 @@ +pub mod github; +pub mod google; +pub mod zendesk; diff --git a/v-api/src/endpoints/login/oauth/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs similarity index 60% rename from v-api/src/endpoints/login/oauth/zendesk.rs rename to v-api/src/endpoints/login/oauth/remote/zendesk.rs index cedc3182..26e1578c 100644 --- a/v-api/src/endpoints/login/oauth/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -10,15 +10,13 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, - OAuthProviderDeviceInfo, - }, - ExternalUserId, UserInfo, UserInfoError, + ExternalUserId, UserInfo, UserInfoError, oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo + } }, }; -use super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; +use super::super::{ExtractUserInfo, OAuthProvider, OAuthProviderName}; pub struct ZendeskOAuthProvider { authz_code_flow_info: Option, @@ -41,65 +39,34 @@ impl ZendeskOAuthProvider { subdomain: String, additional_scopes: Option>, ) -> Self { - // let base_url = format!("https://{}.zendesk.com", subdomain); - - // Self { - // device_public: OAuthPublicCredentials { - // client_id: device_client_id, - // }, - // device_private: Some(OAuthPrivateCredentials { - // client_secret: device_client_secret, - // }), - // web_public: OAuthPublicCredentials { - // client_id: web_client_id, - // }, - // web_private: Some(OAuthPrivateCredentials { - // client_secret: web_client_secret, - // }), - // web_pkce: web_pkce_client_id.map(|client_id| OAuthPublicCredentials { - // client_id, - // }), - // web_pkce_port: web_pkce_port, - // additional_scopes: additional_scopes.unwrap_or_default(), - // client: reqwest::ClientBuilder::new() - // .redirect(reqwest::redirect::Policy::none()) - // .build() - // .expect("Static client must build"), - // user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), - // auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), - // token_exchange_endpoint: format!("{}/oauth/tokens", base_url), - // token_endpoint: Some(format!( - // "{}/login/oauth/zendesk/device/exchange", - // public_url - // )), - // redirect_endpoint: Some(format!("{}/login/oauth/zendesk/code/callback", public_url,)), - // } - let base_url = format!("https://{}.zendesk.com", subdomain); - let mut default_scopes = vec!["users:read".to_string()]; + let mut default_scopes = vec!["read".to_string(), "write".to_string()]; default_scopes.extend(additional_scopes.unwrap_or_default()); let authz_code_flow_info = config.web.map(|web| OAuthProviderAuthorizationCodeInfo { - client_id: web.client_id, - client_secret: web.client_secret.into(), - auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + auth_url_endpoint: format!("{}/login/oauth/zendesk/code/authorize", public_url), redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: format!("{}/oauth/tokens", base_url), - revocation_endpoint: None, + token_endpoint: format!("{}/login/oauth/zendesk/device/exchange", public_url), + remote: OAuthProviderAuthorizationCodeRemoteInfo { + client_id: web.remote_client_id, + client_secret: web.remote_client_secret.into(), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), + token_endpoint: format!("{}/oauth/tokens", base_url), + revocation_endpoint: None + }, }); let authz_code_pkce_flow_info = config .proxy_web - .map(|proxy| OAuthProviderAuthorizationCodePkceInfo { + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { client_id: proxy.client_id, - auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), - redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), - token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: format!("{}/oauth/tokens", base_url), + redirect_endpoint: proxy.redirect_uri, proxy_port: proxy.proxy_port, - revocation_endpoint: None, + web: web.clone() }); Self { @@ -153,6 +120,7 @@ impl ExtractUserInfo for ZendeskOAuthProvider { external_id: ExternalUserId::Zendesk(user.id.to_string()), verified_emails, display_name: Some(user.name), + idp_token: None, }) } } diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml index dc1d1d75..030df94a 100644 --- a/v-cli-sdk/Cargo.toml +++ b/v-cli-sdk/Cargo.toml @@ -21,3 +21,4 @@ serde = { workspace = true } serde_json = { workspace = true } tabwriter = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } +uuid = { workspace = true } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 8b77c2cd..762885da 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -14,6 +14,7 @@ use crate::{ pub trait CliAdapterToken { fn access_token(&self) -> &str; + fn idp_token(&self) -> Option<&str>; } pub trait CliConsumerLoginProvider: Into + Subcommand + Debug + Clone {} @@ -41,11 +42,21 @@ where T: VCliContext, >::Error: StdError + Send + Sync + 'static, { - let access_token = self.method.run(ctx, &self.mode).await?; + let (access_token, idp_token) = self.method.run(ctx, self.mode).await?; ctx.config_mut().set_token(access_token); ctx.config_mut().save()?; + // If we are acquiring an IdP token, present it to the user. + if let Some(idp_token) = idp_token { + println!("\nYou can now additionally authenticate against the requested remote service API \ + with the following token."); + println!("IdP token: {}", idp_token); + println!(""); + println!("Please note that this should be kept secure as calls made with this token are \ + made on behalf of your user acount"); + } + Ok(()) } } @@ -60,6 +71,10 @@ where OAuth { #[command(subcommand)] provider: SupportedProviders, + /// Additionally retrieve a the underlying IdP token. This token is not stored. Remote mode + /// should be used when you need to authenticate to the underlying system frontend by the API + #[arg(long, default_value = "false")] + request_idp_token: bool, }, /// Login via Magic Link #[command(name = "mlink")] @@ -77,7 +92,7 @@ pub enum LoginProvider { Zendesk, } -#[derive(ValueEnum, Debug, Clone, PartialEq)] +#[derive(Copy, ValueEnum, Debug, Clone, PartialEq)] pub enum AuthenticationMode { /// Retrieve and store an identity token. Identity mode is the default and should be used to /// when you do not require extended (multi-day) access @@ -87,23 +102,19 @@ pub enum AuthenticationMode { /// a machine for continued access. This requires the permission to create api tokens #[value(name = "token")] Token, - /// Retrieve and store a remote token. Remote mode should be used when you want to authenticate - /// and retrieve a token for use against the underlying authentication provider - #[value(name = "remote")] - Remote, } impl LoginMethod where SupportedProviders: CliConsumerLoginProvider, { - pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result + pub async fn run(&self, ctx: &T, mode: AuthenticationMode) -> Result<(String, Option)> where T: VCliContext, >::Error: StdError + Send + Sync + 'static, { match self { - Self::OAuth { provider } => { + Self::OAuth { provider, request_idp_token } => { let adapter = ctx.oauth_adapter(); let provider = provider.clone().into(); let provider = adapter.provider(provider).await?; @@ -113,18 +124,21 @@ where // To handle those cases we need to use a proxy path that emulates an authorization // code flow. if provider.device_code_endpoint().is_some() { - self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) - .await + if *request_idp_token { + anyhow::bail!("Remote token access is not supported via device authentication flow"); + } + Ok((self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await?, None)) } else if provider.supports_pkce_only() { - self.run_oauth_code_provider(provider, mode, ctx.oauth_adapter()) + self.run_oauth_code_provider(provider, mode, *request_idp_token, ctx.oauth_adapter()) .await } else { anyhow::bail!("OAuth provider does not support any CLI authentication methods") } } Self::MagicLink { email, scope } => { - self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) - .await + Ok((self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await?, None)) } } } @@ -132,14 +146,14 @@ where async fn run_oauth_device_provider( &self, provider: V, - mode: &AuthenticationMode, + mode: AuthenticationMode, adapter: T, ) -> Result where T: CliOAuthAdapter, V: CliOAuthProviderInfo, { - let oauth_client = oauth::DeviceOAuth::new(provider)?; + let oauth_client = oauth::device::DeviceOAuth::new(provider)?; let details = oauth_client.get_device_authorization().await?; println!( @@ -155,37 +169,52 @@ where Err(err) => Err(anyhow::anyhow!("Authentication failed: {}", err)), }?; - if mode == &AuthenticationMode::Token { - let token = adapter - .get_long_lived_token(identity_token.secret()) - .await?; - Ok(token.access_token().to_string()) - } else { - Ok(identity_token.secret().to_string()) + match mode { + AuthenticationMode::Identity => { + Ok(identity_token.secret().to_string()) + } + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.secret()) + .await?; + Ok(token.access_token().to_string()) + } } } async fn run_oauth_code_provider( &self, provider: T, - mode: &AuthenticationMode, + mode: AuthenticationMode, + request_idp_token: bool, adapter: V, - ) -> Result + ) -> Result<(String, Option)> where T: CliOAuthProviderInfo, V: CliOAuthAdapter + Send + Sync + 'static, { - let oauth_client = oauth::CodeOAuth::new(provider)?; + let oauth_client = oauth::code::CodeOAuth::new(provider)?; let adapter = Arc::new(adapter); - let identity_token = oauth_client.login(Arc::clone(&adapter)).await?; + let identity_token = oauth_client.login(Arc::clone(&adapter), request_idp_token).await?; + + let access_token = match mode { + AuthenticationMode::Identity => identity_token.access_token().to_string(), + AuthenticationMode::Token => { + let token = adapter + .get_long_lived_token(identity_token.access_token()) + .await?; + token.access_token().to_string() + } + }; - if mode == &AuthenticationMode::Token { - let token = adapter.get_long_lived_token(&identity_token).await?; - Ok(token.access_token().to_string()) + let idp_token = if request_idp_token { + identity_token.idp_token().map(|s| s.to_string()) } else { - Ok(identity_token) - } + None + }; + + Ok((access_token, idp_token)) } async fn run_magic_link( diff --git a/v-cli-sdk/src/cmd/auth/oauth.rs b/v-cli-sdk/src/cmd/auth/oauth.rs deleted file mode 100644 index 950e859d..00000000 --- a/v-cli-sdk/src/cmd/auth/oauth.rs +++ /dev/null @@ -1,298 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use std::error::Error as StdError; -use std::future::Future; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; - -use anyhow::Result; -use http::{Request, Response, StatusCode}; -use http_body_util::{BodyExt, Full}; -use hyper::body::{Bytes, Incoming}; -use oauth2::basic::{BasicClient, BasicTokenType}; -use oauth2::{ - AuthType, AuthUrl, ClientId, CsrfToken, DeviceAuthorizationUrl, EmptyExtraTokenFields, - EndpointNotSet, EndpointSet, RedirectUrl, Scope, StandardTokenResponse, TokenUrl, -}; -use oauth2::{PkceCodeChallenge, StandardDeviceAuthorizationResponse}; -use tokio::sync::oneshot; - -use crate::cmd::auth::login::CliAdapterToken; - -use super::proxy::run_proxy_server; - -static PROXY_PORT: u16 = 8174; - -pub trait CliOAuthAdapter { - type Token: CliAdapterToken; - type Error: StdError + Send + Sync + 'static; - - fn provider( - &self, - provider: super::login::LoginProvider, - ) -> Pin> + Send>>; - fn exchange_authorization_code( - &self, - request: Request, - ) -> Pin>, Self::Error>> + Send>>; - fn get_long_lived_token( - &self, - access_token: &str, - ) -> Pin> + Send>>; -} - -pub trait CliOAuthProviderInfo { - fn supports_pkce_only(&self) -> bool; - fn device_code_endpoint(&self) -> Option<&str>; - fn auth_url_endpoint(&self) -> &str; - fn token_endpoint(&self) -> &str; - fn redirect_endpoint(&self) -> Option<&str>; - fn client_id(&self) -> &str; - fn scopes(&self) -> &[String]; -} - -type CodeClient = BasicClient< - // HasAuthUrl - EndpointSet, - // HasDeviceAuthUrl - EndpointNotSet, - // HasIntrospectionUrl - EndpointNotSet, - // HasRevocationUrl - EndpointNotSet, - // HasTokenUrl - EndpointSet, ->; - -pub struct CodeOAuth { - client: CodeClient, - scopes: Vec, - port: u16, -} - -impl CodeOAuth { - pub fn new(provider: T) -> Result - where - T: CliOAuthProviderInfo, - { - let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) - .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) - .set_auth_type(AuthType::RequestBody) - .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) - .set_redirect_uri(RedirectUrl::new( - provider - .redirect_endpoint() - .expect("OAuth code flow provider must define a redirect url") - .to_string(), - )?); - - Ok(Self { - client, - scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), - port: PROXY_PORT, - }) - } - - /// Build the authorization URL that the user should visit in a browser. - /// Returns the full URL and the CSRF state token used for verification. - pub fn authorize_url( - &self, - pkce_challenge: PkceCodeChallenge, - ) -> (oauth2::url::Url, CsrfToken) { - let mut req = self - .client - .authorize_url(CsrfToken::new_random) - .set_pkce_challenge(pkce_challenge); - - for scope in &self.scopes { - req = req.add_scope(Scope::new(scope.to_string())); - } - - req.url() - } - - /// Run the full authorization code login flow: - /// - /// 1. Generate the authorization URL and print it for the user. - /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. - /// 3. Forward the redirect request to the API server via the adapter. - /// 4. Extract the token from the server's response. - /// 5. Return a success page to the browser and shut down the proxy. - pub async fn login(&self, adapter: Arc) -> Result - where - T: CliOAuthAdapter + Send + Sync + 'static, - { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - let (auth_url, _csrf_state) = self.authorize_url(pkce_challenge); - - println!( - "Open the following URL in your browser to authenticate:\n\n {}\n", - auth_url - ); - - // Channel to receive the token extracted from the server response. - let (token_tx, token_rx) = oneshot::channel::>(); - let token_tx: Arc>>>> = - Arc::new(Mutex::new(Some(token_tx))); - - // Channel to shut down the proxy server once we have the token. - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - - let port = self.port; - - // Spawn the local proxy server in a background task. - tokio::spawn({ - let callback_token_tx = Arc::clone(&token_tx); - let error_token_tx = Arc::clone(&token_tx); - - async move { - let callback = Arc::new(move |request: Request| { - let adapter = Arc::clone(&adapter); - let token_tx = Arc::clone(&callback_token_tx); - - Box::pin(async move { - // Forward the redirect request to the API server. - let response = adapter - .exchange_authorization_code(request) - .await - .map_err(|e| anyhow::anyhow!(e))?; - - // The server responds with the access token in the body. - let (_parts, body) = response.into_parts(); - let body_bytes = body - .collect() - .await - .expect("Full collection cannot fail") - .to_bytes(); - let token = String::from_utf8(body_bytes.to_vec())?; - - // Send the token back to the main task. - if let Ok(mut guard) = token_tx.lock() { - if let Some(tx) = guard.take() { - let _ = tx.send(Ok(token)); - } - } - - // Return a friendly page to the browser so the user - // knows they can close the tab. - Ok(Response::builder() - .status(StatusCode::OK) - .header("content-type", "text/html; charset=utf-8") - .body(Full::new(Bytes::from(concat!( - "", - "

Authentication successful!

", - "

You can close this tab and return to the CLI.

", - "" - ))))?) - }) - as Pin< - Box>>> + Send>, - > - }); - - if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { - eprintln!("Proxy server error: {e}"); - - // If the proxy died before we got a token, unblock the - // receiver so the caller isn't stuck forever. - if let Ok(mut guard) = error_token_tx.lock() { - if let Some(tx) = guard.take() { - let _ = tx.send(Err(anyhow::anyhow!( - "Proxy server exited unexpectedly: {e}" - ))); - } - } - } - } - }); - - // Wait for the proxy callback to extract the token. - let token = token_rx.await.map_err(|_| { - anyhow::anyhow!( - "Authentication callback was never received — proxy server may have exited early" - ) - })??; - - // Tell the proxy server to stop. - let _ = shutdown_tx.send(()); - - Ok(token) - } -} - -type DeviceClient = BasicClient< - // HasAuthUrl - EndpointSet, - // HasDeviceAuthUrl - EndpointSet, - // HasIntrospectionUrl - EndpointNotSet, - // HasRevocationUrl - EndpointNotSet, - // HasTokenUrl - EndpointSet, ->; - -pub struct DeviceOAuth { - client: DeviceClient, - http: oauth2_reqwest::ReqwestClient, - scopes: Vec, -} - -impl DeviceOAuth { - pub fn new(provider: T) -> Result - where - T: CliOAuthProviderInfo, - { - if let Some(device_endpoint) = provider.device_code_endpoint() { - let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; - - let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) - .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) - .set_auth_type(AuthType::RequestBody) - .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) - .set_device_authorization_url(device_auth_url); - - Ok(Self { - client, - http: oauth2_reqwest::ReqwestClient::from( - reqwest::ClientBuilder::new() - .redirect(reqwest::redirect::Policy::none()) - .build() - .unwrap(), - ), - scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), - }) - } else { - anyhow::bail!("Device authorization is not supported by this provider") - } - } - - pub async fn login( - &self, - details: &StandardDeviceAuthorizationResponse, - ) -> Result> { - let token = self - .client - .exchange_device_access_token(details) - .set_max_backoff_interval(details.interval()) - .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) - .await; - - Ok(token?) - } - - pub async fn get_device_authorization(&self) -> Result { - let mut req = self.client.exchange_device_code(); - - for scope in &self.scopes { - req = req.add_scope(Scope::new(scope.to_string())); - } - - let res = req.request_async(&self.http).await; - - Ok(res?) - } -} diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs new file mode 100644 index 00000000..fba2ae30 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -0,0 +1,221 @@ +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, +}; + +use anyhow::Result; +use http::{Request, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; + +use oauth2::{ + basic::BasicClient, AuthType, AuthUrl, ClientId, CsrfToken, EndpointNotSet, EndpointSet, + PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, +}; +use tokio::sync::oneshot; +use uuid::Uuid; + +use crate::cmd::auth::{ + oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, proxy::run_proxy_server +}; + +type CodeClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointNotSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct CodeOAuth { + client: CodeClient, + client_id: Uuid, + redirect_uri: String, + scopes: Vec, + port: u16, +} + +impl CodeOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .auth_url_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define an authorization url") + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_redirect_uri(RedirectUrl::new( + provider + .redirect_endpoint() + .ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a redirect url") + })? + .to_string(), + )?); + + Ok(Self { + client, + client_id: provider.client_id(), + redirect_uri: provider.redirect_endpoint().unwrap_or_default().to_string(), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + port: provider.public_pkce_port().ok_or_else(|| { + anyhow::anyhow!("OAuth code flow provider must define a public proxy port") + })?, + }) + } + + /// Build the authorization URL that the user should visit in a browser. + /// Returns the full URL and the CSRF state token used for verification. + pub fn authorize_url( + &self, + pkce_challenge: PkceCodeChallenge, + ) -> (oauth2::url::Url, CsrfToken) { + let mut req = self + .client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + req.url() + } + + /// Run the full authorization code login flow: + /// + /// 1. Generate the authorization URL and print it for the user. + /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. + /// 3. Forward the redirect request to the API server via the adapter. + /// 4. Extract the token from the server's response. + /// 5. Return a success page to the browser and shut down the proxy. + pub async fn login(&self, adapter: Arc, request_idp_token: bool) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, _csrf_state) = self.authorize_url(pkce_challenge); + + println!( + "Open the following URL in your browser to authenticate:\n\n {}\n", + auth_url + ); + + // Channel to receive the token extracted from the server response. + let (token_tx, token_rx) = oneshot::channel::>(); + let token_tx: Arc>>>> = + Arc::new(Mutex::new(Some(token_tx))); + + // Channel to shut down the proxy server once we have the token. + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let port = self.port; + + // Spawn the local proxy server in a background task. + tokio::spawn({ + let callback_token_tx = Arc::clone(&token_tx); + let error_token_tx = Arc::clone(&token_tx); + let client_id = self.client_id.clone(); + let redirect_uri = self.redirect_uri.clone(); + + async move { + let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some(Box::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + let code = request + .uri() + .query() + .and_then(|q: &str| { + q.split('&') + .filter_map(|pair: &str| pair.split_once('=')) + .find(|(key, _): &(&str, &str)| *key == "code") + .map(|(_, value): (&str, &str)| value.to_string()) + }) + .ok_or_else(|| { + anyhow::anyhow!("Missing 'code' query parameter in callback request") + })?; + + // Forward the redirect request to the API server. + let token = adapter + .exchange_authorization_code( + crate::cmd::auth::login::LoginProvider::Zendesk, + client_id.clone(), + redirect_uri.clone(), + "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + ) + .await + .map_err(|e| { + anyhow::anyhow!(e) + })?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Ok(token)); + } + } + + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/html; charset=utf-8") + .body(Full::new(Bytes::from(concat!( + "", + "", + "

Authentication successful. This window should close automatically.

", + "" + ))))?) + }) + as Pin< + Box>>> + Send>, + > + })))); + + if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { + eprintln!("Proxy server error: {e}"); + + // If the proxy died before we got a token, unblock the + // receiver so the caller is not stuck forever. + if let Ok(mut guard) = error_token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); + } + } + } + } + }); + + // Wait for the proxy callback to extract the token. + let token = token_rx.await.map_err(|_| { + anyhow::anyhow!( + "Authentication callback was never received — proxy server may have exited early" + ) + })??; + + // Tell the proxy server to stop. + let _ = shutdown_tx.send(()); + + Ok(token) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/device.rs b/v-cli-sdk/src/cmd/auth/oauth/device.rs new file mode 100644 index 00000000..4acf227b --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/device.rs @@ -0,0 +1,92 @@ +use anyhow::Result; +use oauth2::{ + basic::{BasicClient, BasicTokenType}, + AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, + EndpointSet, Scope, StandardDeviceAuthorizationResponse, StandardTokenResponse, TokenUrl, +}; + +use crate::cmd::auth::oauth::CliOAuthProviderInfo; + +type DeviceClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct DeviceOAuth { + client: DeviceClient, + http: oauth2_reqwest::ReqwestClient, + scopes: Vec, +} + +impl DeviceOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + if let Some(device_endpoint) = provider.device_code_endpoint() { + let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new( + provider + .device_code_endpoint() + .ok_or_else(|| { + anyhow::anyhow!( + "OAuth device flow provider must define an device code url" + ) + })? + .to_string(), + )?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_device_authorization_url(device_auth_url); + + Ok(Self { + client, + http: oauth2_reqwest::ReqwestClient::from( + reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(), + ), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + }) + } else { + anyhow::bail!("Device authorization is not supported by this provider") + } + } + + pub async fn login( + &self, + details: &StandardDeviceAuthorizationResponse, + ) -> Result> { + let token = self + .client + .exchange_device_access_token(details) + .set_max_backoff_interval(details.interval()) + .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) + .await; + + Ok(token?) + } + + pub async fn get_device_authorization(&self) -> Result { + let mut req = self.client.exchange_device_code(); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + let res = req.request_async(&self.http).await; + + Ok(res?) + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs new file mode 100644 index 00000000..afdcf244 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -0,0 +1,50 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use oauth2::PkceCodeVerifier; +use uuid::Uuid; +use std::{error::Error as StdError, future::Future, pin::Pin}; + +pub mod code; +pub mod device; + +use crate::cmd::auth::login::CliAdapterToken; + +pub trait CliOAuthAdapter { + type ShortToken: CliAdapterToken + Send + 'static; + type LongToken: CliAdapterToken + Send + 'static; + type Error: StdError + Send + Sync + 'static; + + fn provider( + &self, + provider: super::login::LoginProvider, + ) -> Pin> + Send>>; + fn exchange_authorization_code( + &self, + provider: super::login::LoginProvider, + client_id: Uuid, + redirect_uri: String, + grant_type: String, + code: String, + pkce_verifier: PkceCodeVerifier, + request_idp_token: bool, + ) -> Pin> + Send>>; + fn get_long_lived_token( + &self, + access_token: &str, + ) -> Pin> + Send>>; +} + +pub trait CliOAuthProviderInfo { + fn client_id(&self) -> Uuid; + fn remote_client_id(&self) -> &str; + fn public_pkce_port(&self) -> Option; + fn supports_pkce_only(&self) -> bool; + fn device_code_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> Option<&str>; + fn token_endpoint(&self) -> &str; + fn redirect_endpoint(&self) -> Option<&str>; + fn scopes(&self) -> &[String]; +} diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs index 47049421..d1b4dd60 100644 --- a/v-cli-sdk/src/cmd/auth/proxy.rs +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -4,26 +4,30 @@ use std::future::Future; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use http_body_util::Full; use hyper::body::{Bytes, Incoming}; use hyper::service::service_fn; -use hyper::{Request, Response}; +use hyper::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use tokio::net::TcpListener; use tokio::sync::oneshot; /// A callback function that receives an incoming HTTP request and returns a response. -pub type Callback = Arc< - dyn Fn( +/// Wrapped in `Arc>>` so it can be called at most once. +pub type CallbackFn = Box< + dyn FnOnce( Request, ) -> std::pin::Pin>>> + Send>> - + Send - + Sync, + + Send, >; +/// A shareable, single-use callback. The first request to arrive `.take()`s +/// the inner function; any subsequent request receives an error response. +pub type Callback = Arc>>; + /// Start a minimal HTTP server on the given port that forwards every incoming /// request to `callback` and returns whatever response the callback produces. /// @@ -63,7 +67,22 @@ async fn serve_loop( tokio::task::spawn(async move { let service = service_fn(move |req: Request| { let cb = Arc::clone(&cb); - async move { cb(req).await } + async move { + let handler = cb + .lock() + .expect("callback mutex poisoned") + .take(); + + match handler { + Some(f) => f(req).await, + None => Ok(Response::builder() + .status(StatusCode::GONE) + .body(Full::new(Bytes::from( + "Callback has already been invoked", + ))) + .expect("building static response cannot fail")), + } + } }); if let Err(err) = @@ -88,14 +107,14 @@ mod tests { #[tokio::test] async fn test_proxy_server_responds() { - let callback: Callback = Arc::new(|_req| { + let callback: Callback = Arc::new(Mutex::new(Some(Box::new(|_req| { Box::pin(async { Ok(Response::builder() .status(StatusCode::OK) .body(Full::new(Bytes::from("hello from callback"))) .unwrap()) }) - }); + })))); let (tx, rx) = oneshot::channel::<()>(); @@ -122,6 +141,15 @@ mod tests { assert_eq!(resp.status(), 200); assert_eq!(resp.text().await.unwrap(), "hello from callback"); + // A second request should get a 410 GONE. + let resp2 = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp2.status(), 410); + // Shut down the server. tx.send(()).unwrap(); server_handle.await.unwrap(); diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index eb7950ab..60cfd5ad 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -40,7 +40,8 @@ impl Display for FormatStyle { } pub trait VCliContext { - type Token; + type ShortToken; + type LongToken; type Error; fn config(&self) -> &impl VCliConfig; @@ -51,10 +52,10 @@ pub trait VCliContext { fn oauth_adapter( &self, - ) -> impl CliOAuthAdapter + Send + Sync + 'static; + ) -> impl CliOAuthAdapter + Send + Sync + 'static; fn mlink_adapter( &self, - ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; + ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; } pub trait VApiErrorMessage { From 09711f53d5dff2f236b35770ff717730010de650 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 14:13:49 -0500 Subject: [PATCH 08/51] Fmt --- v-api/src/config.rs | 2 +- v-api/src/endpoints/login/oauth/flow/code.rs | 11 +- v-api/src/endpoints/login/oauth/mod.rs | 3 +- .../endpoints/login/oauth/remote/github.rs | 8 +- .../endpoints/login/oauth/remote/google.rs | 27 ++--- .../endpoints/login/oauth/remote/zendesk.rs | 29 ++--- v-cli-sdk/src/cmd/auth/login.rs | 60 +++++++---- v-cli-sdk/src/cmd/auth/oauth/code.rs | 102 +++++++++--------- v-cli-sdk/src/cmd/auth/oauth/mod.rs | 2 +- v-cli-sdk/src/cmd/auth/proxy.rs | 6 +- v-cli-sdk/src/lib.rs | 8 +- 11 files changed, 149 insertions(+), 109 deletions(-) diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 18e0e32d..e37e57f3 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -21,10 +21,10 @@ use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; -use v_model::OAuthClientId; use std::path::PathBuf; use thiserror::Error; use v_api_param::{ParamResolutionError, StringParam}; +use v_model::OAuthClientId; use crate::{ authn::{ diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index c6529c26..059466b4 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -535,10 +535,7 @@ where "Malformed credentials presented to code exchange", )), Err(err) => { - tracing::info!( - ?err, - "Failed to extract basic authentication credentials" - ); + tracing::info!(?err, "Failed to extract basic authentication credentials"); Ok(None) } }?; @@ -550,7 +547,11 @@ where // We of course deny underspecifying credentials, but we also want to disallow over specifying // them. For example, if the client provides both basic auth and a client id/secret in the // request body, we should reject the request. - tracing::debug!(?basic_credentials, ?body_credentials, "Extracted credentials from request"); + tracing::debug!( + ?basic_credentials, + ?body_credentials, + "Extracted credentials from request" + ); let (client_id, client_secret) = match (basic_credentials, body_credentials) { (Some(_), (Some(_), _)) => Err(bad_request( "Cannot provide both basic auth and client credentials", diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 3b6c58e6..74412c9c 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -85,7 +85,8 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { .set_auth_uri(AuthUrl::new(info.remote.auth_url_endpoint.clone())?) .set_token_uri(TokenUrl::new(info.remote.token_endpoint.clone())?) .set_revocation_url_option( - info.remote.revocation_endpoint + info.remote + .revocation_endpoint .as_ref() .map(|url| RevocationUrl::new(url.to_string())) .transpose()?, diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index 0d15a685..86705976 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -11,9 +11,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - ExternalUserId, UserInfo, UserInfoError, oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo - } + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, }, }; diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs index 24f02b91..9cb7c6f5 100644 --- a/v-api/src/endpoints/login/oauth/remote/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -10,9 +10,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - ExternalUserId, UserInfo, UserInfoError, oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo - } + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, }, }; @@ -59,16 +61,15 @@ impl GoogleOAuthProvider { revocation_endpoint: Some("https://oauth2.googleapis.com/revoke".to_string()), }, }); - let authz_code_pkce_flow_info = - config - .proxy_web - .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) - .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { - client_id: proxy.client_id, - redirect_endpoint: proxy.redirect_uri, - proxy_port: proxy.proxy_port, - web: web.clone() - }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); let device_code_flow_info = config.device.map(|device| OAuthProviderDeviceInfo { client_id: device.client_id, remote_client_id: device.remote_client_id, diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index 26e1578c..143c6e69 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -10,9 +10,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ - ExternalUserId, UserInfo, UserInfoError, oauth::{ - OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo - } + oauth::{ + OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, + OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, + }, + ExternalUserId, UserInfo, UserInfoError, }, }; @@ -55,19 +57,18 @@ impl ZendeskOAuthProvider { auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), token_endpoint: format!("{}/oauth/tokens", base_url), - revocation_endpoint: None + revocation_endpoint: None, }, }); - let authz_code_pkce_flow_info = - config - .proxy_web - .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) - .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { - client_id: proxy.client_id, - redirect_endpoint: proxy.redirect_uri, - proxy_port: proxy.proxy_port, - web: web.clone() - }); + let authz_code_pkce_flow_info = config + .proxy_web + .and_then(|proxy| authz_code_flow_info.as_ref().map(|web| (web, proxy))) + .map(|(web, proxy)| OAuthProviderAuthorizationCodePkceInfo { + client_id: proxy.client_id, + redirect_endpoint: proxy.redirect_uri, + proxy_port: proxy.proxy_port, + web: web.clone(), + }); Self { authz_code_flow_info, diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 762885da..c079c82e 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -49,12 +49,16 @@ where // If we are acquiring an IdP token, present it to the user. if let Some(idp_token) = idp_token { - println!("\nYou can now additionally authenticate against the requested remote service API \ - with the following token."); + println!( + "\nYou can now additionally authenticate against the requested remote service API \ + with the following token." + ); println!("IdP token: {}", idp_token); println!(""); - println!("Please note that this should be kept secure as calls made with this token are \ - made on behalf of your user acount"); + println!( + "Please note that this should be kept secure as calls made with this token are \ + made on behalf of your user acount" + ); } Ok(()) @@ -108,13 +112,20 @@ impl LoginMethod where SupportedProviders: CliConsumerLoginProvider, { - pub async fn run(&self, ctx: &T, mode: AuthenticationMode) -> Result<(String, Option)> + pub async fn run( + &self, + ctx: &T, + mode: AuthenticationMode, + ) -> Result<(String, Option)> where T: VCliContext, >::Error: StdError + Send + Sync + 'static, { match self { - Self::OAuth { provider, request_idp_token } => { + Self::OAuth { + provider, + request_idp_token, + } => { let adapter = ctx.oauth_adapter(); let provider = provider.clone().into(); let provider = adapter.provider(provider).await?; @@ -125,21 +136,32 @@ where // code flow. if provider.device_code_endpoint().is_some() { if *request_idp_token { - anyhow::bail!("Remote token access is not supported via device authentication flow"); + anyhow::bail!( + "Remote token access is not supported via device authentication flow" + ); } - Ok((self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) - .await?, None)) + Ok(( + self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await?, + None, + )) } else if provider.supports_pkce_only() { - self.run_oauth_code_provider(provider, mode, *request_idp_token, ctx.oauth_adapter()) - .await + self.run_oauth_code_provider( + provider, + mode, + *request_idp_token, + ctx.oauth_adapter(), + ) + .await } else { anyhow::bail!("OAuth provider does not support any CLI authentication methods") } } - Self::MagicLink { email, scope } => { - Ok((self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) - .await?, None)) - } + Self::MagicLink { email, scope } => Ok(( + self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await?, + None, + )), } } @@ -170,9 +192,7 @@ where }?; match mode { - AuthenticationMode::Identity => { - Ok(identity_token.secret().to_string()) - } + AuthenticationMode::Identity => Ok(identity_token.secret().to_string()), AuthenticationMode::Token => { let token = adapter .get_long_lived_token(identity_token.secret()) @@ -196,7 +216,9 @@ where let oauth_client = oauth::code::CodeOAuth::new(provider)?; let adapter = Arc::new(adapter); - let identity_token = oauth_client.login(Arc::clone(&adapter), request_idp_token).await?; + let identity_token = oauth_client + .login(Arc::clone(&adapter), request_idp_token) + .await?; let access_token = match mode { AuthenticationMode::Identity => identity_token.access_token().to_string(), diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index fba2ae30..298d943d 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -17,7 +17,8 @@ use tokio::sync::oneshot; use uuid::Uuid; use crate::cmd::auth::{ - oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, proxy::run_proxy_server + oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, + proxy::run_proxy_server, }; type CodeClient = BasicClient< @@ -132,50 +133,51 @@ impl CodeOAuth { let redirect_uri = self.redirect_uri.clone(); async move { - let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some(Box::new(move |request: Request| { - let adapter = Arc::clone(&adapter); - let token_tx = Arc::clone(&callback_token_tx); - - Box::pin(async move { - let code = request - .uri() - .query() - .and_then(|q: &str| { - q.split('&') - .filter_map(|pair: &str| pair.split_once('=')) - .find(|(key, _): &(&str, &str)| *key == "code") - .map(|(_, value): (&str, &str)| value.to_string()) - }) - .ok_or_else(|| { - anyhow::anyhow!("Missing 'code' query parameter in callback request") - })?; - - // Forward the redirect request to the API server. - let token = adapter - .exchange_authorization_code( - crate::cmd::auth::login::LoginProvider::Zendesk, - client_id.clone(), - redirect_uri.clone(), - "authorization_code".to_string(), - code, - pkce_verifier, - request_idp_token, - ) - .await - .map_err(|e| { - anyhow::anyhow!(e) - })?; - - // Send the token back to the main task. - if let Ok(mut guard) = token_tx.lock() { - if let Some(tx) = guard.take() { - let _ = tx.send(Ok(token)); + let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some( + Box::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + let code = request + .uri() + .query() + .and_then(|q: &str| { + q.split('&') + .filter_map(|pair: &str| pair.split_once('=')) + .find(|(key, _): &(&str, &str)| *key == "code") + .map(|(_, value): (&str, &str)| value.to_string()) + }) + .ok_or_else(|| { + anyhow::anyhow!( + "Missing 'code' query parameter in callback request" + ) + })?; + + // Forward the redirect request to the API server. + let token = adapter + .exchange_authorization_code( + crate::cmd::auth::login::LoginProvider::Zendesk, + client_id.clone(), + redirect_uri.clone(), + "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + ) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Ok(token)); + } } - } - // Return a friendly page to the browser so the user - // knows they can close the tab. - Ok(Response::builder() + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() .status(StatusCode::OK) .header("content-type", "text/html; charset=utf-8") .body(Full::new(Bytes::from(concat!( @@ -184,11 +186,15 @@ impl CodeOAuth { "

Authentication successful. This window should close automatically.

", "" ))))?) - }) - as Pin< - Box>>> + Send>, - > - })))); + }) + as Pin< + Box< + dyn Future>>> + + Send, + >, + > + }), + ))); if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { eprintln!("Proxy server error: {e}"); diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs index afdcf244..8d7841a0 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -4,8 +4,8 @@ use anyhow::Result; use oauth2::PkceCodeVerifier; -use uuid::Uuid; use std::{error::Error as StdError, future::Future, pin::Pin}; +use uuid::Uuid; pub mod code; pub mod device; diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs index d1b4dd60..527584f7 100644 --- a/v-cli-sdk/src/cmd/auth/proxy.rs +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -19,9 +19,9 @@ use tokio::sync::oneshot; pub type CallbackFn = Box< dyn FnOnce( Request, - ) - -> std::pin::Pin>>> + Send>> - + Send, + ) -> std::pin::Pin< + Box>>> + Send>, + > + Send, >; /// A shareable, single-use callback. The first request to arrive `.take()`s diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index 60cfd5ad..27f678a5 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -52,7 +52,13 @@ pub trait VCliContext { fn oauth_adapter( &self, - ) -> impl CliOAuthAdapter + Send + Sync + 'static; + ) -> impl CliOAuthAdapter< + ShortToken = Self::ShortToken, + LongToken = Self::LongToken, + Error = Self::Error, + > + Send + + Sync + + 'static; fn mlink_adapter( &self, ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; From 8fea7ea34e75c6b88a7b04bb30b076e00ffe88ee Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 14:14:20 -0500 Subject: [PATCH 09/51] Clippy lints --- v-api/src/config.rs | 4 ++-- v-api/src/endpoints/login/oauth/flow/code.rs | 2 +- v-api/src/endpoints/login/oauth/flow/device_token.rs | 2 +- v-cli-sdk/src/cmd/auth/login.rs | 2 +- v-cli-sdk/src/cmd/auth/oauth/code.rs | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/v-api/src/config.rs b/v-api/src/config.rs index e37e57f3..a12a688c 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -218,7 +218,7 @@ impl OAuthDeviceConfig { ) -> Result { let remote_client_secret = self.remote_client_secret.resolve(base)?; Ok(ResolvedOAuthDeviceConfig { - client_id: self.client_id.clone(), + client_id: self.client_id, remote_client_id: self.remote_client_id.clone(), remote_client_secret, }) @@ -242,7 +242,7 @@ impl OAuthWebProxyConfig { _base: Option, ) -> Result { Ok(ResolvedOAuthWebProxyConfig { - client_id: self.client_id.clone(), + client_id: self.client_id, redirect_uri: self.redirect_uri.clone(), proxy_port: self.proxy_port, }) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 059466b4..80ebe9f4 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -286,7 +286,7 @@ fn oauth_redirect_response( .add_scopes( provider .default_scopes() - .into_iter() + .iter() .map(|s| Scope::new(s.to_string())) .collect::>(), ); diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs index d18ddae4..b20753dc 100644 --- a/v-api/src/endpoints/login/oauth/flow/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -145,7 +145,7 @@ where let device_info = device_info.unwrap(); let exchange_request = body.into_inner(); - let exchange = AccessTokenExchange::new(exchange_request, &device_info); + let exchange = AccessTokenExchange::new(exchange_request, device_info); let client = reqwest::Client::new(); diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index c079c82e..64b2b36b 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -54,7 +54,7 @@ where with the following token." ); println!("IdP token: {}", idp_token); - println!(""); + println!(); println!( "Please note that this should be kept secure as calls made with this token are \ made on behalf of your user acount" diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 298d943d..c8e520dc 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -129,7 +129,7 @@ impl CodeOAuth { tokio::spawn({ let callback_token_tx = Arc::clone(&token_tx); let error_token_tx = Arc::clone(&token_tx); - let client_id = self.client_id.clone(); + let client_id = self.client_id; let redirect_uri = self.redirect_uri.clone(); async move { @@ -158,7 +158,7 @@ impl CodeOAuth { let token = adapter .exchange_authorization_code( crate::cmd::auth::login::LoginProvider::Zendesk, - client_id.clone(), + client_id, redirect_uri.clone(), "authorization_code".to_string(), code, From 3aeddc4744554f1aa211349bfdfece0e1ebc4d67 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 14:33:47 -0500 Subject: [PATCH 10/51] Redirect fixes --- v-api/src/context/mod.rs | 7 +++-- v-api/src/endpoints/login/magic_link/mod.rs | 7 +++-- v-api/src/endpoints/login/mod.rs | 31 ++++++++++++++++++++ v-api/src/endpoints/login/oauth/flow/code.rs | 2 +- v-api/src/endpoints/login/oauth/mod.rs | 9 +++--- 5 files changed, 46 insertions(+), 10 deletions(-) diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index afe51505..73808f31 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1317,10 +1317,13 @@ pub(crate) mod test_mocks { Box::new(ZendeskOAuthProvider::new( ResolvedOAuthConfig { device: None, - web: None, + web: Some(ResolvedOAuthWebConfig { + remote_client_id: "zendesk_web_client_id".to_string(), + remote_client_secret: "zendesk_web_client_secret".to_string().into(), + }), proxy_web: Some(ResolvedOAuthWebProxyConfig { client_id: TypedUuid::new_v4(), - redirect_uri: "test".to_string(), + redirect_uri: "https://test_public_url/pkce-callback".to_string(), proxy_port: 1234, }), }, diff --git a/v-api/src/endpoints/login/magic_link/mod.rs b/v-api/src/endpoints/login/magic_link/mod.rs index b666b7cc..e33d6a9e 100644 --- a/v-api/src/endpoints/login/magic_link/mod.rs +++ b/v-api/src/endpoints/login/magic_link/mod.rs @@ -313,8 +313,9 @@ impl CheckMagicLinkClient for MagicLink { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + super::is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index ef683c2e..54073743 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -10,6 +10,7 @@ use serde::{ Deserialize, Deserializer, Serialize, Serializer, }; use thiserror::Error; +use url::Url; use crate::{ permissions::VPermission, @@ -212,3 +213,33 @@ pub enum UserInfoError { pub trait UserInfoProvider { async fn get_user_info(&self, token: &str) -> Result; } + +/// Structurally compare a candidate redirect URI against a list of registered redirect URIs. +/// Comparison is performed on scheme, host, port, and path. URIs that fail to parse or contain +/// fragments are rejected (per RFC 6749 §3.1.2). +pub fn is_redirect_uri_valid<'a>( + redirect_uri: &str, + registered_uris: impl Iterator, +) -> bool { + let candidate = match Url::parse(redirect_uri) { + Ok(url) => url, + Err(_) => return false, + }; + + // Reject redirect URIs that contain fragments (per RFC 6749 §3.1.2) + if candidate.fragment().is_some() { + return false; + } + + registered_uris.into_iter().any(|registered| { + match Url::parse(registered) { + Ok(registered) => { + registered.scheme() == candidate.scheme() + && registered.host() == candidate.host() + && registered.port() == candidate.port() + && registered.path() == candidate.path() + } + Err(_) => false, + } + }) +} diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 80ebe9f4..74999836 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -912,7 +912,7 @@ mod tests { .unwrap(); let secret_signature = key.signature().to_string(); let client_secret = key.key(); - let redirect_uri = "callback-destination"; + let redirect_uri = "https://example.com/callback"; ( ctx, diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 74412c9c..27b0a468 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -24,7 +24,7 @@ use crate::{ secrets::OpenApiSecretString, }; -use super::{UserInfo, UserInfoError, UserInfoProvider}; +use super::{is_redirect_uri_valid, UserInfo, UserInfoError, UserInfoProvider}; pub mod client; pub mod flow; @@ -246,8 +246,9 @@ impl CheckOAuthClient for OAuthClient { fn is_redirect_uri_valid(&self, redirect_uri: &str) -> bool { tracing::trace!(?redirect_uri, valid_uris = ?self.redirect_uris, "Checking redirect uri against list of valid uris"); - self.redirect_uris - .iter() - .any(|r| r.redirect_uri == redirect_uri) + is_redirect_uri_valid( + redirect_uri, + self.redirect_uris.iter().map(|r| r.redirect_uri.as_str()), + ) } } From 532feaca989c9e99d8a632bcb915974b9ceeeb47 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 14:47:35 -0500 Subject: [PATCH 11/51] Filter proxy responses --- .../login/oauth/flow/device_token.rs | 275 +++++++++++++++--- 1 file changed, 227 insertions(+), 48 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs index b20753dc..e7b0c051 100644 --- a/v-api/src/endpoints/login/oauth/flow/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -4,7 +4,8 @@ use chrono::{DateTime, Utc}; use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{header, HeaderValue, Response, StatusCode}; +use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; +use hyper::body::Bytes; use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; use schemars::JsonSchema; use secrecy::ExposeSecret; @@ -177,11 +178,7 @@ where // report their error back to the client tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - Ok(client_response) + Ok(proxy_upstream_response(bytes, headers, status)) } else { // The server gave us back a non-error response but it still may not be a success. // GitHub for instance does not use a status code for indicating the success or failure @@ -245,49 +242,231 @@ where "Failed to parse a success response from the remote token endpoint" ); - // Try to deserialize the body again, but this time as an error - let mut error_response = match serde_json::from_slice::(&bytes) { - Ok(error) => { - // We found an error in the message body. This is not ideal, but we at - // least can understand what the server was trying to tell us - tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - client_response - } - Err(_) => { - // We still do not know what the remote server is doing... and need to - // cancel the request ourselves - tracing::warn!( - "Remote OAuth provide returned a response that we do not undestand" - ); - - Response::new( - serde_json::to_vec(&ProxyTokenError { - error: "access_denied".to_string(), - error_description: Some(format!( - "{} returned a malformed response", - path.provider - )), - error_uri: None, - }) - .unwrap() - .into(), - ) - } - }; - - *error_response.status_mut() = StatusCode::BAD_REQUEST; - error_response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - Ok(error_response) + Ok(handle_token_parse_failure( + &path.provider.to_string(), + bytes, + headers, + status, + )) } } } } + +/// Headers that are safe to forward from an upstream OAuth provider response. +/// Only `Content-Type` is needed so the client can parse the body. Polling backoff +/// is handled via the JSON body per RFC 8628 (`interval` field / `slow_down` error), +/// not via HTTP headers. +const FORWARDED_HEADERS: &[header::HeaderName] = &[ + header::CONTENT_TYPE, +]; + +/// Copy only allowlisted headers from an upstream response to avoid forwarding +/// dangerous headers such as `Set-Cookie`, `Location`, or CORS headers. +fn filter_upstream_headers(upstream: &HeaderMap) -> HeaderMap { + let mut filtered = HeaderMap::new(); + for name in FORWARDED_HEADERS { + if let Some(value) = upstream.get(name) { + filtered.insert(name.clone(), value.clone()); + } + } + filtered +} + +/// Build a response to the client by proxying an upstream provider's response. This is used +/// when the upstream provider returns a non-success status code. +fn proxy_upstream_response(bytes: Bytes, headers: HeaderMap, status: StatusCode) -> Response { + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + client_response +} + +/// Handle the case where the upstream provider returned a 200 status but the body could not be +/// parsed as a valid token response. We try to interpret the body as an error response and proxy +/// it back. If the body is not a recognizable error either, we return our own error. +fn handle_token_parse_failure( + provider_name: &str, + bytes: Bytes, + headers: HeaderMap, + status: StatusCode, +) -> Response { + // Try to deserialize the body as an error + let mut error_response = match serde_json::from_slice::(&bytes) { + Ok(error) => { + // We found an error in the message body. This is not ideal, but we at + // least can understand what the server was trying to tell us + tracing::debug!(?error, provider_name, "Parsed error response from OAuth provider"); + + let mut client_response = Response::new(Body::from(bytes)); + *client_response.headers_mut() = filter_upstream_headers(&headers); + *client_response.status_mut() = status; + + client_response + } + Err(_) => { + // We still do not know what the remote server is doing... and need to + // cancel the request ourselves + tracing::warn!("Remote OAuth provider returned a response that we do not understand"); + + Response::new( + serde_json::to_vec(&ProxyTokenError { + error: "access_denied".to_string(), + error_description: Some(format!( + "{} returned a malformed response", + provider_name + )), + error_uri: None, + }) + .unwrap() + .into(), + ) + } + }; + + *error_response.status_mut() = StatusCode::BAD_REQUEST; + error_response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + error_response +} + +#[cfg(test)] +mod tests { + use http::{ + header::{self, HeaderName, SET_COOKIE}, + HeaderMap, HeaderValue, StatusCode, + }; + use hyper::body::Bytes; + + use super::{handle_token_parse_failure, proxy_upstream_response}; + + #[test] + fn test_upstream_set_cookie_is_stripped_from_error_response() { + // A malicious or compromised upstream provider includes a Set-Cookie header + // that would set a cookie on our API's domain in the user's browser + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("session=malicious-value; Path=/; HttpOnly"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::FORBIDDEN); + + // The Set-Cookie header must NOT be forwarded to our client + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded to the client" + ); + // But Content-Type should still be forwarded + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be forwarded from upstream" + ); + } + + #[test] + fn test_upstream_cors_headers_are_stripped_from_error_response() { + // A malicious upstream provider injects permissive CORS headers that would + // weaken our API's cross-origin protections + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("*"), + ); + upstream_headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + + let body = Bytes::from_static(b"{\"error\": \"authorization_pending\"}"); + let response = proxy_upstream_response(body, upstream_headers, StatusCode::BAD_REQUEST); + + // CORS headers must NOT be forwarded from upstream + assert!( + response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none(), + "Upstream CORS header must not be forwarded to the client" + ); + assert!( + response.headers().get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS).is_none(), + "Upstream CORS credentials header must not be forwarded to the client" + ); + } + + #[test] + fn test_upstream_location_and_framing_headers_are_stripped_from_token_parse_failure() { + // When the upstream returns a 200 status but the body is an error (not a valid + // token), handle_token_parse_failure must not forward dangerous headers. + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + header::LOCATION, + HeaderValue::from_static("https://evil.example.com/phishing"), + ); + upstream_headers.insert( + HeaderName::from_static("x-frame-options"), + HeaderValue::from_static("ALLOW-FROM https://evil.example.com"), + ); + + // Body that parses as a ProxyTokenError but NOT as a valid token + let body = Bytes::from_static( + b"{\"error\": \"slow_down\", \"error_description\": null, \"error_uri\": null}", + ); + let response = handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // Dangerous headers must NOT be forwarded + assert!( + response.headers().get(header::LOCATION).is_none(), + "Upstream Location header must not be forwarded to the client" + ); + assert!( + response.headers().get("x-frame-options").is_none(), + "Upstream X-Frame-Options header must not be forwarded to the client" + ); + // But Content-Type should still be present (set by the function itself) + assert!( + response.headers().get(header::CONTENT_TYPE).is_some(), + "Content-Type should be present on the response" + ); + } + + #[test] + fn test_upstream_set_cookie_is_stripped_from_token_parse_failure() { + // Even when the upstream returns 200 and the body is a parseable error, + // a Set-Cookie header must not be forwarded to our client + let mut upstream_headers = HeaderMap::new(); + upstream_headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + upstream_headers.insert( + SET_COOKIE, + HeaderValue::from_static("tracking=evil-tracker; Domain=.our-api.com; Path=/"), + ); + + let body = Bytes::from_static( + b"{\"error\": \"authorization_pending\", \"error_description\": null, \"error_uri\": null}", + ); + let response = handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + + // The Set-Cookie header must NOT be forwarded + assert!( + response.headers().get(SET_COOKIE).is_none(), + "Upstream Set-Cookie header must not be forwarded via token parse failure path" + ); + } +} From aa36bb1dd3a171c45bed6625b3647c94f5c65e3c Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:08:10 -0500 Subject: [PATCH 12/51] More pkce support --- v-api/src/endpoints/login/oauth/flow/code.rs | 86 ++++++++++++------- v-api/src/endpoints/login/oauth/mod.rs | 8 +- .../endpoints/login/oauth/remote/github.rs | 3 + .../endpoints/login/oauth/remote/google.rs | 3 + .../endpoints/login/oauth/remote/zendesk.rs | 3 + 5 files changed, 71 insertions(+), 32 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 74999836..156d8eca 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -99,6 +99,10 @@ pub struct OAuthAuthzCodeQuery { pub response_type: String, pub state: String, pub scope: Option, + /// PKCE code challenge (RFC 7636). Required for all authorization code flows. + pub code_challenge: String, + /// PKCE code challenge method. Must be "S256". + pub code_challenge_method: String, } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -198,6 +202,16 @@ where tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); + // Validate the client's PKCE challenge method. Only S256 is supported. + if query.code_challenge_method != "S256" { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Unsupported code_challenge_method. Only S256 is supported.".to_string()), + error_uri: None, + state: None, + }.into()); + } + // Find the configured provider for the requested remote backend. We should always have a valid // provider value, so if this fails then a 500 is returned let provider = ctx @@ -237,8 +251,14 @@ where // process once before storing so all downstream consumers see the encoded value. attempt.state = Some(percent_encode(query.state.as_bytes(), NON_ALPHANUMERIC).to_string()); - // If the remote provider supports pkce, set up a challenge - let pkce_challenge = if provider.supports_pkce() { + // Always store the client's PKCE challenge so we can verify it during the token exchange. + // This is the client-to-v-api PKCE leg and is mandatory for all flows. + attempt.pkce_challenge = Some(query.code_challenge); + attempt.pkce_challenge_method = Some(query.code_challenge_method); + + // If the remote provider supports PKCE, also set up a challenge for the v-api-to-remote leg. + // This is independent of the client-to-v-api PKCE above. + let remote_pkce_challenge = if provider.supports_pkce() { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); attempt.provider_pkce_verifier = Some(pkce_verifier.secret().to_string()); Some(pkce_challenge) @@ -255,7 +275,7 @@ where tracing::info!(?attempt.id, "Created login attempt"); - oauth_redirect_response(ctx.public_url(), &*provider, &attempt, pkce_challenge) + oauth_redirect_response(ctx.public_url(), &*provider, &attempt, remote_pkce_challenge) } fn oauth_redirect_response( @@ -477,7 +497,8 @@ pub struct OAuthAuthzCodeExchangeBody { pub redirect_uri: String, pub grant_type: String, pub code: String, - pub pkce_verifier: Option, + /// PKCE code verifier (RFC 7636). Required for all authorization code exchanges. + pub pkce_verifier: String, } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -603,7 +624,7 @@ where &attempt, client_id, &body.redirect_uri, - body.pkce_verifier.as_deref(), + &body.pkce_verifier, )?; tracing::debug!("Verified login attempt"); @@ -740,7 +761,7 @@ fn verify_login_attempt( attempt: &LoginAttempt, client_id: TypedUuid, redirect_uri: &str, - pkce_verifier: Option<&str>, + pkce_verifier: &str, ) -> Result<(), OAuthError> { if attempt.client_id != client_id { Err(OAuthError { @@ -771,16 +792,10 @@ fn verify_login_attempt( state: None, }) } else { - match (attempt.pkce_challenge.as_deref(), pkce_verifier) { - (Some(_), None) => Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }), - (Some(challenge), Some(verifier)) => { + match attempt.pkce_challenge.as_deref() { + Some(challenge) => { let mut hasher = Sha256::new(); - hasher.update(verifier); + hasher.update(pkce_verifier); let hash = hasher.finalize(); let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash); @@ -795,7 +810,14 @@ fn verify_login_attempt( }) } } - (None, _) => Ok(()), + // PKCE is mandatory for all authorization code flows. A missing challenge + // means the login attempt was not properly initialized. + None => Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), + error_uri: None, + state: None, + }), } } } @@ -1643,7 +1665,7 @@ mod tests { &bad_client_id, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1664,7 +1686,7 @@ mod tests { &bad_redirect_uri, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1685,7 +1707,7 @@ mod tests { &unconfirmed_state, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1706,7 +1728,7 @@ mod tests { &already_used_state, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1727,7 +1749,7 @@ mod tests { &failed_state, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1748,25 +1770,31 @@ mod tests { &expired, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); - let missing_pkce = LoginAttempt { ..attempt.clone() }; + // Verify that a login attempt with no stored PKCE challenge is rejected. + // PKCE is mandatory, so a missing challenge means the attempt is invalid. + let missing_challenge = LoginAttempt { + pkce_challenge: None, + pkce_challenge_method: None, + ..attempt.clone() + }; assert_eq!( OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Login attempt is missing a PKCE challenge".to_string()), error_uri: None, state: None, }, verify_login_attempt( - &missing_pkce, + &missing_challenge, attempt.client_id, &attempt.redirect_uri, - None, + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1787,7 +1815,7 @@ mod tests { &invalid_pkce, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap_err() ); @@ -1798,7 +1826,7 @@ mod tests { &attempt, attempt.client_id, &attempt.redirect_uri, - Some(verifier.secret().as_str()), + verifier.secret().as_str(), ) .unwrap() ); diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 27b0a468..524967be 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -74,9 +74,11 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn default_scopes(&self) -> &[String]; - fn supports_pkce(&self) -> bool { - true - } + /// Whether the remote OAuth provider supports PKCE (RFC 7636). Providers must + /// explicitly declare this. This controls whether v-api sends a PKCE challenge + /// to the remote provider during the authorization code exchange. Note: clients + /// calling v-api are always required to use PKCE regardless of this setting. + fn supports_pkce(&self) -> bool; fn as_web_client(&self) -> Result { match self.authz_code_flow_info() { diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index 86705976..7d59e260 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -144,6 +144,9 @@ impl OAuthProvider for GitHubOAuthProvider { fn default_scopes(&self) -> &[String] { &self.default_scopes } + fn supports_pkce(&self) -> bool { + false + } fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { self.authz_code_flow_info.as_ref() diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs index 9cb7c6f5..23879e89 100644 --- a/v-api/src/endpoints/login/oauth/remote/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -169,6 +169,9 @@ impl OAuthProvider for GoogleOAuthProvider { fn default_scopes(&self) -> &[String] { &self.default_scopes } + fn supports_pkce(&self) -> bool { + true + } fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { self.authz_code_flow_info.as_ref() diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index 143c6e69..51ecd31a 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -140,6 +140,9 @@ impl OAuthProvider for ZendeskOAuthProvider { fn default_scopes(&self) -> &[String] { &self.default_scopes } + fn supports_pkce(&self) -> bool { + true + } fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { self.authz_code_flow_info.as_ref() From d4c69923cf5f1b615cdbe03276157a2c206d41bf Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:12:35 -0500 Subject: [PATCH 13/51] Early return on invalid scope --- v-api/src/endpoints/login/oauth/flow/code.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 156d8eca..e1f8e10d 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -223,9 +223,15 @@ where // Check that the passed in scopes are valid. The scopes are not currently restricted by client let scope = query.scope.unwrap_or_else(|| DEFAULT_SCOPE.to_string()); - let scope_error = VPermission::from_scope_arg(&scope) - .err() - .map(|_| "invalid_scope".to_string()); + if let Err(err) = VPermission::from_scope_arg(&scope) { + tracing::warn!(?err, ?scope, "Client submitted an invalid scope"); + return Err(OAuthError { + error: OAuthErrorCode::InvalidScope, + error_description: Some(format!("Invalid scope: {}", scope)), + error_uri: None, + state: None, + }.into()); + } // Construct a new login attempt with the minimum required values let mut attempt = NewLoginAttempt::new( @@ -243,9 +249,6 @@ where // TODO: Make this configurable attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); - // Assign any scope errors that arose - attempt.error = scope_error; - // Add in the user defined state and redirect uri. State is an arbitrary value and may be // malicious. It must be url-encoded before being presented back to the client. Therefore we // process once before storing so all downstream consumers see the encoded value. From 5d94c95761e4d7bd85659a38881acc454d67107f Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:20:12 -0500 Subject: [PATCH 14/51] Add check on provider --- v-api/src/endpoints/login/oauth/flow/code.rs | 77 +++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index e1f8e10d..0faf9ca8 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -625,6 +625,7 @@ where // Verify that the login attempt is valid and matches the submitted client credentials verify_login_attempt( &attempt, + &provider.name().to_string(), client_id, &body.redirect_uri, &body.pkce_verifier, @@ -762,11 +763,19 @@ where fn verify_login_attempt( attempt: &LoginAttempt, + provider: &str, client_id: TypedUuid, redirect_uri: &str, pkce_verifier: &str, ) -> Result<(), OAuthError> { - if attempt.client_id != client_id { + if attempt.provider != provider { + Err(OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }) + } else if attempt.client_id != client_id { Err(OAuthError { error: OAuthErrorCode::InvalidGrant, error_description: Some("Invalid client id".to_string()), @@ -1666,6 +1675,7 @@ mod tests { }, verify_login_attempt( &bad_client_id, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1687,6 +1697,7 @@ mod tests { }, verify_login_attempt( &bad_redirect_uri, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1708,6 +1719,7 @@ mod tests { }, verify_login_attempt( &unconfirmed_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1729,6 +1741,7 @@ mod tests { }, verify_login_attempt( &already_used_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1750,6 +1763,7 @@ mod tests { }, verify_login_attempt( &failed_state, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1771,6 +1785,7 @@ mod tests { }, verify_login_attempt( &expired, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1795,6 +1810,7 @@ mod tests { }, verify_login_attempt( &missing_challenge, + &attempt.provider, attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1816,6 +1832,63 @@ mod tests { }, verify_login_attempt( &invalid_pkce, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap_err() + ); + + assert_eq!( + (), + verify_login_attempt( + &attempt, + &attempt.provider, + attempt.client_id, + &attempt.redirect_uri, + verifier.secret().as_str(), + ) + .unwrap() + ); + } + + #[tokio::test] + async fn test_provider_mismatch_is_rejected() { + let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); + + // Login attempt was created via Google + let attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("ox_state".to_string()), + pkce_challenge: Some(challenge.as_str().to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + // Exchanging against a different provider must fail + assert_eq!( + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Provider mismatch".to_string()), + error_uri: None, + state: None, + }, + verify_login_attempt( + &attempt, + "github", attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), @@ -1823,10 +1896,12 @@ mod tests { .unwrap_err() ); + // Exchanging against the correct provider must succeed assert_eq!( (), verify_login_attempt( &attempt, + "google", attempt.client_id, &attempt.redirect_uri, verifier.secret().as_str(), From b87a944846bfa3bd94c057c0239597e855450dd6 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:38:55 -0500 Subject: [PATCH 15/51] Error fixes --- v-api/src/endpoints/login/oauth/flow/code.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 0faf9ca8..209ee8c1 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -65,7 +65,7 @@ struct OAuthError { } #[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -#[serde(untagged)] +#[serde(rename_all = "snake_case")] enum OAuthErrorCode { AccessDenied, InvalidClient, From 77383eec05252609dfcd05f8b3effd1481f78100 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:48:38 -0500 Subject: [PATCH 16/51] Adding redirect url validation --- Cargo.lock | 1 + .../src/endpoints/login/magic_link/client.rs | 13 +++++- v-api/src/endpoints/login/oauth/client.rs | 13 +++++- v-api/src/endpoints/login/oauth/flow/code.rs | 11 +++-- v-model/Cargo.toml | 1 + v-model/src/lib.rs | 42 ++++++++++--------- 6 files changed, 56 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9dc0cfb4..51078d23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3939,6 +3939,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tracing", + "url", "uuid", "v-api-installer", ] diff --git a/v-api/src/endpoints/login/magic_link/client.rs b/v-api/src/endpoints/login/magic_link/client.rs index 58a70380..287e3134 100644 --- a/v-api/src/endpoints/login/magic_link/client.rs +++ b/v-api/src/endpoints/login/magic_link/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ permissions::{Caller, PermissionStorage}, MagicLink, MagicLinkId, MagicLinkRedirectUri, MagicLinkRedirectUriId, MagicLinkSecret, @@ -19,7 +20,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, VContext, }; @@ -192,6 +193,16 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request("Invalid redirect URI: must not contain a fragment")); + } + Ok(HttpResponseOk( ctx.magic_link .add_magic_link_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index b7d799d7..e5fb6fbd 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -8,6 +8,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::instrument; +use url::Url; use v_model::{ permissions::{Caller, PermissionStorage}, OAuthClient, OAuthClientId, OAuthClientRedirectUri, OAuthClientSecret, OAuthRedirectUriId, @@ -19,7 +20,7 @@ use crate::{ context::{ApiContext, VContextWithCaller}, permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, - util::response::to_internal_error, + util::response::{bad_request, to_internal_error}, VContext, }; @@ -191,6 +192,16 @@ where let (ctx, caller) = rqctx.as_ctx().await?; let path = path.into_inner(); let body = body.into_inner(); + + // Validate that the redirect URI is a well-formed URL before storing it. + // Per RFC 6749 §3.1.2, redirect URIs must be absolute URIs and must not + // include a fragment component. + let parsed = Url::parse(&body.redirect_uri) + .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; + if parsed.fragment().is_some() { + return Err(bad_request("Invalid redirect URI: must not contain a fragment")); + } + Ok(HttpResponseOk( ctx.oauth .add_oauth_redirect_uri(&caller, &path.client_id, &body.redirect_uri) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 209ee8c1..154b0f58 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -485,7 +485,10 @@ where }; // Redirect back to the original authenticator - Ok(attempt.callback_url()) + attempt.callback_url().map_err(|err| { + tracing::error!(?err, redirect_uri = ?attempt.redirect_uri, "Login attempt contains an invalid redirect URI"); + to_internal_error(err) + }) } #[derive(Debug, Deserialize, JsonSchema)] @@ -1250,7 +1253,7 @@ mod tests { .unwrap(); assert_eq!( - format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",), + format!("https://test.oxeng.dev/callback?state=ox_state&error=server_error",), location ); } @@ -1310,7 +1313,7 @@ mod tests { .unwrap(); assert_eq!( - format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",), + format!("https://test.oxeng.dev/callback?state=ox_state&error=access_denied",), location ); } @@ -1370,7 +1373,7 @@ mod tests { let lock = extracted_code.lock(); assert_eq!( format!( - "https://test.oxeng.dev/callback?code={}&state=ox_state", + "https://test.oxeng.dev/callback?state=ox_state&code={}", lock.unwrap().as_ref().unwrap() ), location diff --git a/v-model/Cargo.toml b/v-model/Cargo.toml index 87d5755a..63234757 100644 --- a/v-model/Cargo.toml +++ b/v-model/Cargo.toml @@ -24,6 +24,7 @@ serde_json = { workspace = true } steno = { workspace = true, optional = true } thiserror = { workspace = true } tracing = { workspace = true } +url = { workspace = true } uuid = { workspace = true, features = ["v4", "serde"] } [dev-dependencies] diff --git a/v-model/src/lib.rs b/v-model/src/lib.rs index b02c4efa..bd320a6d 100644 --- a/v-model/src/lib.rs +++ b/v-model/src/lib.rs @@ -16,10 +16,11 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::{ - collections::{BTreeMap, BTreeSet}, + collections::BTreeSet, fmt::Display, }; use thiserror::Error; +use url::Url; pub mod db; pub mod permissions; @@ -282,26 +283,22 @@ pub struct LoginAttempt { } impl LoginAttempt { - pub fn callback_url(&self) -> String { - let mut params = BTreeMap::new(); - - if let Some(state) = &self.state { - params.insert("state", state); - } - - if let Some(error) = &self.error { - params.insert("error", error); - } else if let Some(authz_code) = &self.authz_code { - params.insert("code", authz_code); + pub fn callback_url(&self) -> Result { + let mut url = Url::parse(&self.redirect_uri)?; + + { + let mut pairs = url.query_pairs_mut(); + if let Some(state) = &self.state { + pairs.append_pair("state", state); + } + if let Some(error) = &self.error { + pairs.append_pair("error", error); + } else if let Some(authz_code) = &self.authz_code { + pairs.append_pair("code", authz_code); + } } - let query_string = params - .into_iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("&"); - - [self.redirect_uri.as_str(), query_string.as_str()].join("?") + Ok(url.to_string()) } } @@ -324,6 +321,13 @@ impl NewLoginAttempt { redirect_uri: String, scope: String, ) -> Result { + // Validate that the redirect URI is a well-formed URL. This ensures + // callback_url() can always parse it later. + Url::parse(&redirect_uri).map_err(|err| InvalidValueError { + field: "redirect_uri".to_string(), + error: format!("Invalid URL: {}", err), + })?; + Ok(Self { id: TypedUuid::new_v4(), attempt_state: LoginAttemptState::New, From 014ea592cdffe52327f4340f8d6278752f92b5dc Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:52:07 -0500 Subject: [PATCH 17/51] Validate pkce challenge --- v-api/src/endpoints/login/oauth/flow/code.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 154b0f58..25b9958b 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -212,6 +212,20 @@ where }.into()); } + // Validate the PKCE code challenge. For S256, this must be a base64url-no-pad + // encoding of a SHA256 hash, which is always exactly 43 characters of [A-Za-z0-9_-] + // (RFC 7636 §4.2). + if query.code_challenge.len() != 43 + || !query.code_challenge.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') + { + return Err(OAuthError { + error: OAuthErrorCode::InvalidRequest, + error_description: Some("Invalid code_challenge. Must be a base64url-encoded SHA256 hash (43 characters).".to_string()), + error_uri: None, + state: None, + }.into()); + } + // Find the configured provider for the requested remote backend. We should always have a valid // provider value, so if this fails then a 500 is returned let provider = ctx From 12abe10935dbc9fdd37cb1ce5abf1eed283c697c Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:55:22 -0500 Subject: [PATCH 18/51] Handle idp user info errors --- v-api/src/endpoints/login/mod.rs | 5 +++++ v-api/src/endpoints/login/oauth/mod.rs | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index 54073743..ee2481ea 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -207,6 +207,11 @@ pub enum UserInfoError { Locked, #[error("User information is missing")] MissingUserInfoData(String), + #[error("User info endpoint returned HTTP {status} for {endpoint}")] + UnexpectedStatus { + endpoint: String, + status: http::StatusCode, + }, } #[async_trait] diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 524967be..b54e0024 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -138,8 +138,17 @@ where ); let response = self.client().execute(request).await?; + let status = response.status(); - tracing::trace!(status = ?response.status(), "Received response from OAuth provider"); + tracing::trace!(?status, "Received response from OAuth provider"); + + if !status.is_success() { + tracing::error!(?status, endpoint, "User info endpoint returned non-success status"); + return Err(UserInfoError::UnexpectedStatus { + endpoint: endpoint.to_string(), + status, + }); + } let bytes = response.bytes().await?; responses.push(bytes); From 0a5c7d582d07aa1956f68ad9fd4125ce420cc5d0 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 15:58:35 -0500 Subject: [PATCH 19/51] Skip serialize on secret --- v-api/src/endpoints/login/oauth/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index b54e0024..6424299b 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -181,6 +181,7 @@ pub struct OAuthProviderAuthorizationCodeInfo { #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct OAuthProviderAuthorizationCodeRemoteInfo { client_id: String, + #[serde(skip_serializing)] client_secret: OpenApiSecretString, auth_url_endpoint: String, token_endpoint_content_type: String, @@ -200,6 +201,7 @@ pub struct OAuthProviderAuthorizationCodePkceInfo { pub struct OAuthProviderDeviceInfo { client_id: TypedUuid, remote_client_id: String, + #[serde(skip_serializing)] remote_client_secret: OpenApiSecretString, device_code_endpoint: String, token_endpoint_content_type: String, From 0f903fc8d156e3be8cf0affd89446bd02b6cc1b7 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 16:02:33 -0500 Subject: [PATCH 20/51] Cookie scoping --- v-api/src/endpoints/login/oauth/flow/code.rs | 99 +++++++++++++++++--- 1 file changed, 87 insertions(+), 12 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 25b9958b..44bdf89c 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -51,8 +51,22 @@ use crate::{ }; static LOGIN_ATTEMPT_COOKIE: &str = "__v_login"; +static LOGIN_ATTEMPT_COOKIE_PATH: &str = "/login/oauth/"; static DEFAULT_SCOPE: &str = "user:info:r"; +/// Build the login attempt cookie with consistent attributes. +/// The `Path` is scoped to the OAuth login endpoints so the cookie is not +/// sent to unrelated paths on the same domain. +fn build_login_attempt_cookie<'a>(value: &'a str, public_url: &str, max_age_secs: i64) -> Cookie<'a> { + let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, value.to_string()); + cookie.set_path(LOGIN_ATTEMPT_COOKIE_PATH); + cookie.set_http_only(true); + cookie.set_same_site(SameSite::Lax); + cookie.set_secure(public_url.starts_with("https")); + cookie.set_max_age(cookie::time::Duration::seconds(max_age_secs)); + cookie +} + #[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] struct OAuthError { error: OAuthErrorCode, @@ -309,12 +323,8 @@ fn oauth_redirect_response( // Create an attempt cookie header for storing the login attempt. This also acts as our csrf // check - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, attempt.id.to_string()); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(public_url.starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(600)); - + let attempt_id_str = attempt.id.to_string(); + let cookie = build_login_attempt_cookie(&attempt_id_str, public_url, 600); let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; // Generate the url to the remote provider that the user will be redirected to @@ -420,11 +430,7 @@ where let attempt_id = verify_csrf(&rqctx.request, &query)?; // Clear the login attempt cookie - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, ""); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(ctx.public_url().starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(0)); + let cookie = build_login_attempt_cookie("", ctx.public_url(), 0); let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; let mut redirect = http_response_temporary_redirect( @@ -1078,7 +1084,7 @@ mod tests { ); assert_eq!( format!( - "{}; HttpOnly; SameSite=Lax; Secure; Max-Age=600", + "{}; HttpOnly; SameSite=Lax; Secure; Path=/login/oauth/; Max-Age=600", attempt.id ) .as_str(), @@ -1926,4 +1932,73 @@ mod tests { .unwrap() ); } + + #[test] + fn test_login_attempt_cookie_has_path() { + let cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "https://example.com", + 600, + ); + + assert_eq!(cookie.path(), Some(super::LOGIN_ATTEMPT_COOKIE_PATH)); + } + + #[test] + fn test_login_attempt_cookie_is_http_only() { + let cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "https://example.com", + 600, + ); + + assert_eq!(cookie.http_only(), Some(true)); + } + + #[test] + fn test_login_attempt_cookie_is_same_site_lax() { + let cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "https://example.com", + 600, + ); + + assert_eq!(cookie.same_site(), Some(cookie::SameSite::Lax)); + } + + #[test] + fn test_login_attempt_cookie_is_secure_for_https() { + let https_cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "https://example.com", + 600, + ); + assert_eq!(https_cookie.secure(), Some(true)); + + let http_cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "http://localhost", + 600, + ); + assert_eq!(http_cookie.secure(), Some(false)); + } + + #[test] + fn test_login_attempt_clear_cookie_has_same_path() { + // The clear cookie must use the same Path as the set cookie, + // otherwise browsers won't clear it. + let set_cookie = super::build_login_attempt_cookie( + "test-attempt-id", + "https://example.com", + 600, + ); + let clear_cookie = super::build_login_attempt_cookie( + "", + "https://example.com", + 0, + ); + + assert_eq!(set_cookie.path(), clear_cookie.path()); + assert_eq!(clear_cookie.max_age(), Some(cookie::time::Duration::seconds(0))); + } } From 68f7f3ce80c6c996352e2494c5c17042c307f737 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 16:21:02 -0500 Subject: [PATCH 21/51] Fixes for login attempt state transitions --- v-api/src/context/login.rs | 43 +++++++----- v-api/src/context/mod.rs | 12 ++++ v-api/src/endpoints/login/oauth/flow/code.rs | 73 ++++++++++++-------- v-model/src/storage/mod.rs | 9 +++ v-model/src/storage/postgres.rs | 34 ++++++++- 5 files changed, 124 insertions(+), 47 deletions(-) diff --git a/v-api/src/context/login.rs b/v-api/src/context/login.rs index c3ddcb49..75d07fa3 100644 --- a/v-api/src/context/login.rs +++ b/v-api/src/context/login.rs @@ -44,14 +44,12 @@ where attempt: LoginAttempt, code: String, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.provider_authz_code = Some(code); + let mut update: NewLoginAttempt = attempt.into(); + update.provider_authz_code = Some(code); + update.attempt_state = LoginAttemptState::RemoteAuthenticated; + update.authz_code = Some(CsrfToken::new_random().secret().to_string()); - // TODO: Internal state changes to the struct - attempt.attempt_state = LoginAttemptState::RemoteAuthenticated; - attempt.authz_code = Some(CsrfToken::new_random().secret().to_string()); - - LoginAttemptStore::upsert(&*self.storage, attempt).await + LoginAttemptStore::update_if_state(&*self.storage, update, LoginAttemptState::New).await } pub async fn get_login_attempt( @@ -84,25 +82,38 @@ where Ok(attempts.pop()) } - pub async fn complete_login_attempt( + /// Atomically claim a login attempt by transitioning it from `RemoteAuthenticated` + /// to `Complete`. Returns an error if the attempt has already been claimed by a + /// concurrent request (i.e., the state is no longer `RemoteAuthenticated`). + /// This must be called before exchanging the authorization code with the remote + /// provider to prevent the same code from being used twice (RFC 6749 §4.1.2). + pub async fn claim_login_attempt( &self, attempt: LoginAttempt, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Complete; - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Complete; + + LoginAttemptStore::update_if_state( + &*self.storage, + update, + LoginAttemptState::RemoteAuthenticated, + ) + .await } pub async fn fail_login_attempt( &self, attempt: LoginAttempt, + expected_state: LoginAttemptState, error: Option<&str>, provider_error: Option<&str>, ) -> Result { - let mut attempt: NewLoginAttempt = attempt.into(); - attempt.attempt_state = LoginAttemptState::Failed; - attempt.error = error.map(|s| s.to_string()); - attempt.provider_error = provider_error.map(|s| s.to_string()); - LoginAttemptStore::upsert(&*self.storage, attempt).await + let mut update: NewLoginAttempt = attempt.into(); + update.attempt_state = LoginAttemptState::Failed; + update.error = error.map(|s| s.to_string()); + update.provider_error = provider_error.map(|s| s.to_string()); + + LoginAttemptStore::update_if_state(&*self.storage, update, expected_state).await } } diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index 73808f31..4e5fb41c 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1647,6 +1647,18 @@ pub(crate) mod test_mocks { .upsert(attempt) .await } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: v_model::LoginAttemptState, + ) -> Result { + self.login_attempt_store + .as_ref() + .unwrap() + .update_if_state(attempt, expected_state) + .await + } } #[async_trait] diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 44bdf89c..e642ee7b 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -498,7 +498,7 @@ where // TODO: Specialize the returned error ctx.login - .fail_login_attempt(attempt, Some(error_message), error.as_deref()) + .fail_login_attempt(attempt, LoginAttemptState::New, Some(error_message), error.as_deref()) .await .map_err(to_internal_error)? } @@ -656,8 +656,30 @@ where tracing::debug!("Verified login attempt"); - // Now that the attempt has been confirmed, use it to fetch user information form the remote - // provider + // Atomically claim this login attempt before doing any remote work. This transitions + // the attempt from RemoteAuthenticated -> Complete in a single conditional UPDATE, + // ensuring that a concurrent request using the same authorization code will fail. + // Per RFC 6749 §4.1.2, authorization codes MUST be single-use. + let attempt_id = attempt.id; + attempt = ctx + .login + .claim_login_attempt(attempt) + .await + .map_err(|err| { + tracing::warn!(?err, ?attempt_id, "Failed to claim login attempt (may have been consumed by a concurrent request)"); + OAuthError { + error: OAuthErrorCode::InvalidGrant, + error_description: Some("Authorization code has already been used".to_string()), + error_uri: None, + state: None, + } + })?; + + tracing::debug!("Claimed login attempt"); + + // Now that the attempt has been claimed, use it to fetch user information from the + // remote provider. If this fails, the attempt is already consumed and the user must + // re-authenticate. let info = fetch_user_info( ctx.public_url(), &ctx.web_client(), @@ -670,24 +692,6 @@ where tracing::debug!("Retrieved user information from remote provider"); - // During fetch_user_info we revoke any downstream codes if possible, therefore At this point we - // consider the login attempt to be consumed and can no longer be used. We state transition to - // complete, even though we may fail further along in the handler. If a failure occurs then the - // user will need to re-authenticate. - attempt = ctx - .login - .complete_login_attempt(attempt) - .await - .map_err(|err| { - tracing::error!(?err, "Failed to complete login attempt"); - OAuthError { - error: OAuthErrorCode::ServerError, - error_description: Some("An unexpected error occurred".to_string()), - error_uri: None, - state: None, - } - })?; - // Register this user as an API user if needed let (api_user_info, api_user_provider) = ctx .register_api_user(&ctx.builtin_registration_user(), info) @@ -1249,9 +1253,12 @@ mod tests { .returning(move |_| Ok(Some(original_attempt.clone()))); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; @@ -1309,9 +1316,12 @@ mod tests { .returning(move |_| Ok(Some(original_attempt.clone()))); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::Failed + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; @@ -1371,9 +1381,12 @@ mod tests { let extracted_code = Arc::new(Mutex::new(None)); let extractor = extracted_code.clone(); attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::RemoteAuthenticated) - .returning(move |arg| { + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { let mut returned = attempt.clone(); returned.attempt_state = arg.attempt_state; returned.authz_code = arg.authz_code; diff --git a/v-model/src/storage/mod.rs b/v-model/src/storage/mod.rs index 114a77ec..70568dd3 100644 --- a/v-model/src/storage/mod.rs +++ b/v-model/src/storage/mod.rs @@ -270,6 +270,15 @@ pub trait LoginAttemptStore { pagination: &ListPagination, ) -> Result, StoreError>; async fn upsert(&self, attempt: NewLoginAttempt) -> Result; + /// Atomically update a login attempt, but only if it is currently in the `expected_state`. + /// The `attempt` must have the desired target state and any other field updates already set. + /// Returns `StoreError::InvariantFailed` if the attempt is not in the expected state, + /// which prevents TOCTOU races on state transitions. + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result; } #[derive(Debug, Default)] diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index 5daaaf5c..16a38f82 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -29,7 +29,7 @@ use crate::{ magic_link_client_redirect_uri, magic_link_client_secret, mapper, oauth_client, oauth_client_redirect_uri, oauth_client_secret, }, - schema_ext::MagicLinkAttemptState, + schema_ext::{LoginAttemptState, MagicLinkAttemptState}, storage::{LinkRequestFilter, LinkRequestStore, StoreError}, AccessGroup, AccessGroupId, AccessToken, AccessTokenId, ApiKey, ApiKeyId, ApiUser, ApiUserContactEmail, ApiUserInfo, ApiUserProvider, LinkRequest, LinkRequestId, LoginAttempt, @@ -775,6 +775,38 @@ impl LoginAttemptStore for PostgresStore { Ok(attempt_m.into()) } + + async fn update_if_state( + &self, + attempt: NewLoginAttempt, + expected_state: LoginAttemptState, + ) -> Result { + let conn = self.pool.get().await?; + let result: Option = update( + login_attempt::dsl::login_attempt + .filter(login_attempt::id.eq(attempt.id.into_untyped_uuid())) + .filter(login_attempt::attempt_state.eq(expected_state)), + ) + .set(( + login_attempt::attempt_state.eq(attempt.attempt_state.clone()), + login_attempt::authz_code.eq(attempt.authz_code), + login_attempt::expires_at.eq(attempt.expires_at), + login_attempt::error.eq(attempt.error), + login_attempt::provider_authz_code.eq(attempt.provider_authz_code), + login_attempt::provider_error.eq(attempt.provider_error), + )) + .get_result_async::(&*conn) + .await + .optional()?; + + match result { + Some(attempt) => Ok(LoginAttempt::from(attempt)), + None => Err(StoreError::InvariantFailed(format!( + "Login attempt {} is not in expected state for transition to {}", + attempt.id, attempt.attempt_state + ))), + } + } } #[async_trait] From 418126db9217c8216acf62bfee7b715395fc9611 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 16:21:11 -0500 Subject: [PATCH 22/51] Fmt --- .../src/endpoints/login/magic_link/client.rs | 4 +- v-api/src/endpoints/login/mod.rs | 8 +- v-api/src/endpoints/login/oauth/client.rs | 4 +- v-api/src/endpoints/login/oauth/flow/code.rs | 102 ++++++++++-------- .../login/oauth/flow/device_token.rs | 26 +++-- v-api/src/endpoints/login/oauth/mod.rs | 6 +- v-model/src/lib.rs | 5 +- 7 files changed, 90 insertions(+), 65 deletions(-) diff --git a/v-api/src/endpoints/login/magic_link/client.rs b/v-api/src/endpoints/login/magic_link/client.rs index 287e3134..6013eb38 100644 --- a/v-api/src/endpoints/login/magic_link/client.rs +++ b/v-api/src/endpoints/login/magic_link/client.rs @@ -200,7 +200,9 @@ where let parsed = Url::parse(&body.redirect_uri) .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; if parsed.fragment().is_some() { - return Err(bad_request("Invalid redirect URI: must not contain a fragment")); + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); } Ok(HttpResponseOk( diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index ee2481ea..6c9ef239 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -236,8 +236,9 @@ pub fn is_redirect_uri_valid<'a>( return false; } - registered_uris.into_iter().any(|registered| { - match Url::parse(registered) { + registered_uris + .into_iter() + .any(|registered| match Url::parse(registered) { Ok(registered) => { registered.scheme() == candidate.scheme() && registered.host() == candidate.host() @@ -245,6 +246,5 @@ pub fn is_redirect_uri_valid<'a>( && registered.path() == candidate.path() } Err(_) => false, - } - }) + }) } diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index e5fb6fbd..085d1966 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -199,7 +199,9 @@ where let parsed = Url::parse(&body.redirect_uri) .map_err(|_| bad_request("Invalid redirect URI: not a valid URL"))?; if parsed.fragment().is_some() { - return Err(bad_request("Invalid redirect URI: must not contain a fragment")); + return Err(bad_request( + "Invalid redirect URI: must not contain a fragment", + )); } Ok(HttpResponseOk( diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index e642ee7b..d5abcb73 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -57,7 +57,11 @@ static DEFAULT_SCOPE: &str = "user:info:r"; /// Build the login attempt cookie with consistent attributes. /// The `Path` is scoped to the OAuth login endpoints so the cookie is not /// sent to unrelated paths on the same domain. -fn build_login_attempt_cookie<'a>(value: &'a str, public_url: &str, max_age_secs: i64) -> Cookie<'a> { +fn build_login_attempt_cookie<'a>( + value: &'a str, + public_url: &str, + max_age_secs: i64, +) -> Cookie<'a> { let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, value.to_string()); cookie.set_path(LOGIN_ATTEMPT_COOKIE_PATH); cookie.set_http_only(true); @@ -220,24 +224,34 @@ where if query.code_challenge_method != "S256" { return Err(OAuthError { error: OAuthErrorCode::InvalidRequest, - error_description: Some("Unsupported code_challenge_method. Only S256 is supported.".to_string()), + error_description: Some( + "Unsupported code_challenge_method. Only S256 is supported.".to_string(), + ), error_uri: None, state: None, - }.into()); + } + .into()); } // Validate the PKCE code challenge. For S256, this must be a base64url-no-pad // encoding of a SHA256 hash, which is always exactly 43 characters of [A-Za-z0-9_-] // (RFC 7636 §4.2). if query.code_challenge.len() != 43 - || !query.code_challenge.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') + || !query + .code_challenge + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') { return Err(OAuthError { error: OAuthErrorCode::InvalidRequest, - error_description: Some("Invalid code_challenge. Must be a base64url-encoded SHA256 hash (43 characters).".to_string()), + error_description: Some( + "Invalid code_challenge. Must be a base64url-encoded SHA256 hash (43 characters)." + .to_string(), + ), error_uri: None, state: None, - }.into()); + } + .into()); } // Find the configured provider for the requested remote backend. We should always have a valid @@ -258,7 +272,8 @@ where error_description: Some(format!("Invalid scope: {}", scope)), error_uri: None, state: None, - }.into()); + } + .into()); } // Construct a new login attempt with the minimum required values @@ -306,7 +321,12 @@ where tracing::info!(?attempt.id, "Created login attempt"); - oauth_redirect_response(ctx.public_url(), &*provider, &attempt, remote_pkce_challenge) + oauth_redirect_response( + ctx.public_url(), + &*provider, + &attempt, + remote_pkce_challenge, + ) } fn oauth_redirect_response( @@ -498,7 +518,12 @@ where // TODO: Specialize the returned error ctx.login - .fail_login_attempt(attempt, LoginAttemptState::New, Some(error_message), error.as_deref()) + .fail_login_attempt( + attempt, + LoginAttemptState::New, + Some(error_message), + error.as_deref(), + ) .await .map_err(to_internal_error)? } @@ -666,7 +691,11 @@ where .claim_login_attempt(attempt) .await .map_err(|err| { - tracing::warn!(?err, ?attempt_id, "Failed to claim login attempt (may have been consumed by a concurrent request)"); + tracing::warn!( + ?err, + ?attempt_id, + "Failed to claim login attempt (may have been consumed by a concurrent request)" + ); OAuthError { error: OAuthErrorCode::InvalidGrant, error_description: Some("Authorization code has already been used".to_string()), @@ -1948,51 +1977,36 @@ mod tests { #[test] fn test_login_attempt_cookie_has_path() { - let cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "https://example.com", - 600, - ); + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); assert_eq!(cookie.path(), Some(super::LOGIN_ATTEMPT_COOKIE_PATH)); } #[test] fn test_login_attempt_cookie_is_http_only() { - let cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "https://example.com", - 600, - ); + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); assert_eq!(cookie.http_only(), Some(true)); } #[test] fn test_login_attempt_cookie_is_same_site_lax() { - let cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "https://example.com", - 600, - ); + let cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); assert_eq!(cookie.same_site(), Some(cookie::SameSite::Lax)); } #[test] fn test_login_attempt_cookie_is_secure_for_https() { - let https_cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "https://example.com", - 600, - ); + let https_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); assert_eq!(https_cookie.secure(), Some(true)); - let http_cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "http://localhost", - 600, - ); + let http_cookie = + super::build_login_attempt_cookie("test-attempt-id", "http://localhost", 600); assert_eq!(http_cookie.secure(), Some(false)); } @@ -2000,18 +2014,14 @@ mod tests { fn test_login_attempt_clear_cookie_has_same_path() { // The clear cookie must use the same Path as the set cookie, // otherwise browsers won't clear it. - let set_cookie = super::build_login_attempt_cookie( - "test-attempt-id", - "https://example.com", - 600, - ); - let clear_cookie = super::build_login_attempt_cookie( - "", - "https://example.com", - 0, - ); + let set_cookie = + super::build_login_attempt_cookie("test-attempt-id", "https://example.com", 600); + let clear_cookie = super::build_login_attempt_cookie("", "https://example.com", 0); assert_eq!(set_cookie.path(), clear_cookie.path()); - assert_eq!(clear_cookie.max_age(), Some(cookie::time::Duration::seconds(0))); + assert_eq!( + clear_cookie.max_age(), + Some(cookie::time::Duration::seconds(0)) + ); } } diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs index e7b0c051..064ff2d5 100644 --- a/v-api/src/endpoints/login/oauth/flow/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -257,9 +257,7 @@ where /// Only `Content-Type` is needed so the client can parse the body. Polling backoff /// is handled via the JSON body per RFC 8628 (`interval` field / `slow_down` error), /// not via HTTP headers. -const FORWARDED_HEADERS: &[header::HeaderName] = &[ - header::CONTENT_TYPE, -]; +const FORWARDED_HEADERS: &[header::HeaderName] = &[header::CONTENT_TYPE]; /// Copy only allowlisted headers from an upstream response to avoid forwarding /// dangerous headers such as `Set-Cookie`, `Location`, or CORS headers. @@ -296,7 +294,11 @@ fn handle_token_parse_failure( Ok(error) => { // We found an error in the message body. This is not ideal, but we at // least can understand what the server was trying to tell us - tracing::debug!(?error, provider_name, "Parsed error response from OAuth provider"); + tracing::debug!( + ?error, + provider_name, + "Parsed error response from OAuth provider" + ); let mut client_response = Response::new(Body::from(bytes)); *client_response.headers_mut() = filter_upstream_headers(&headers); @@ -395,11 +397,17 @@ mod tests { // CORS headers must NOT be forwarded from upstream assert!( - response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none(), + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none(), "Upstream CORS header must not be forwarded to the client" ); assert!( - response.headers().get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS).is_none(), + response + .headers() + .get(header::ACCESS_CONTROL_ALLOW_CREDENTIALS) + .is_none(), "Upstream CORS credentials header must not be forwarded to the client" ); } @@ -426,7 +434,8 @@ mod tests { let body = Bytes::from_static( b"{\"error\": \"slow_down\", \"error_description\": null, \"error_uri\": null}", ); - let response = handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); // Dangerous headers must NOT be forwarded assert!( @@ -461,7 +470,8 @@ mod tests { let body = Bytes::from_static( b"{\"error\": \"authorization_pending\", \"error_description\": null, \"error_uri\": null}", ); - let response = handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); + let response = + handle_token_parse_failure("test-provider", body, upstream_headers, StatusCode::OK); // The Set-Cookie header must NOT be forwarded assert!( diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 6424299b..bc071c2c 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -143,7 +143,11 @@ where tracing::trace!(?status, "Received response from OAuth provider"); if !status.is_success() { - tracing::error!(?status, endpoint, "User info endpoint returned non-success status"); + tracing::error!( + ?status, + endpoint, + "User info endpoint returned non-success status" + ); return Err(UserInfoError::UnexpectedStatus { endpoint: endpoint.to_string(), status, diff --git a/v-model/src/lib.rs b/v-model/src/lib.rs index bd320a6d..ac7b2f7e 100644 --- a/v-model/src/lib.rs +++ b/v-model/src/lib.rs @@ -15,10 +15,7 @@ use schema_ext::MagicLinkAttemptState; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{ - collections::BTreeSet, - fmt::Display, -}; +use std::{collections::BTreeSet, fmt::Display}; use thiserror::Error; use url::Url; From 1b284c2f253f5bb9897a73fbb34c840d840bdcd4 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 16:29:30 -0500 Subject: [PATCH 23/51] More clippy fixes --- v-cli-sdk/src/cmd/auth/login.rs | 2 ++ v-cli-sdk/src/cmd/auth/oauth/code.rs | 17 ++++++++++------- v-cli-sdk/src/cmd/auth/oauth/mod.rs | 22 +++++++++++++++------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 64b2b36b..3e21691a 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -266,11 +266,13 @@ pub trait CliMagicLinkAdapter { type Token: CliAdapterToken; type Error: StdError + Send + Sync + 'static; + #[allow(clippy::type_complexity)] fn create_attempt( &self, email: &str, scope: Option<&str>, ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] fn exchange( &self, attempt: Self::Attempt, diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index c8e520dc..70ecc0b5 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -117,6 +117,7 @@ impl CodeOAuth { // Channel to receive the token extracted from the server response. let (token_tx, token_rx) = oneshot::channel::>(); + #[allow(clippy::type_complexity)] let token_tx: Arc>>>> = Arc::new(Mutex::new(Some(token_tx))); @@ -157,13 +158,15 @@ impl CodeOAuth { // Forward the redirect request to the API server. let token = adapter .exchange_authorization_code( - crate::cmd::auth::login::LoginProvider::Zendesk, - client_id, - redirect_uri.clone(), - "authorization_code".to_string(), - code, - pkce_verifier, - request_idp_token, + super::AuthorizationCodeExchange { + provider: crate::cmd::auth::login::LoginProvider::Zendesk, + client_id, + redirect_uri: redirect_uri.clone(), + grant_type: "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + }, ) .await .map_err(|e| anyhow::anyhow!(e))?; diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs index 8d7841a0..8a65605a 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -12,25 +12,33 @@ pub mod device; use crate::cmd::auth::login::CliAdapterToken; +/// Parameters for exchanging an authorization code for an access token. +pub struct AuthorizationCodeExchange { + pub provider: super::login::LoginProvider, + pub client_id: Uuid, + pub redirect_uri: String, + pub grant_type: String, + pub code: String, + pub pkce_verifier: PkceCodeVerifier, + pub request_idp_token: bool, +} + pub trait CliOAuthAdapter { type ShortToken: CliAdapterToken + Send + 'static; type LongToken: CliAdapterToken + Send + 'static; type Error: StdError + Send + Sync + 'static; + #[allow(clippy::type_complexity)] fn provider( &self, provider: super::login::LoginProvider, ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] fn exchange_authorization_code( &self, - provider: super::login::LoginProvider, - client_id: Uuid, - redirect_uri: String, - grant_type: String, - code: String, - pkce_verifier: PkceCodeVerifier, - request_idp_token: bool, + exchange: AuthorizationCodeExchange, ) -> Pin> + Send>>; + #[allow(clippy::type_complexity)] fn get_long_lived_token( &self, access_token: &str, From bf412d0006f9b2f3fbb314bc91cc7fc59eeafdcc Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 16:29:39 -0500 Subject: [PATCH 24/51] Fmt --- v-cli-sdk/src/cmd/auth/oauth/code.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 70ecc0b5..0177d874 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -157,17 +157,15 @@ impl CodeOAuth { // Forward the redirect request to the API server. let token = adapter - .exchange_authorization_code( - super::AuthorizationCodeExchange { - provider: crate::cmd::auth::login::LoginProvider::Zendesk, - client_id, - redirect_uri: redirect_uri.clone(), - grant_type: "authorization_code".to_string(), - code, - pkce_verifier, - request_idp_token, - }, - ) + .exchange_authorization_code(super::AuthorizationCodeExchange { + provider: crate::cmd::auth::login::LoginProvider::Zendesk, + client_id, + redirect_uri: redirect_uri.clone(), + grant_type: "authorization_code".to_string(), + code, + pkce_verifier, + request_idp_token, + }) .await .map_err(|e| anyhow::anyhow!(e))?; From 91b18c3fe219ba4cfbcf02009c505c739265bcf0 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 18:36:24 -0500 Subject: [PATCH 25/51] More spec compliance --- v-api/src/endpoints/login/mod.rs | 86 +++++++++++++++++++ v-api/src/endpoints/login/oauth/flow/code.rs | 42 +++++++++ .../login/oauth/flow/device_token.rs | 55 +++++++++++- 3 files changed, 182 insertions(+), 1 deletion(-) diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index 6c9ef239..0a55ea85 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -244,7 +244,93 @@ pub fn is_redirect_uri_valid<'a>( && registered.host() == candidate.host() && registered.port() == candidate.port() && registered.path() == candidate.path() + && registered.query() == candidate.query() } Err(_) => false, }) } + +#[cfg(test)] +mod tests { + use super::is_redirect_uri_valid; + + #[test] + fn test_redirect_uri_exact_match() { + assert!(is_redirect_uri_valid( + "https://example.com/callback", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_different_host() { + assert!(!is_redirect_uri_valid( + "https://evil.com/callback", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_different_path() { + assert!(!is_redirect_uri_valid( + "https://example.com/other", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_fragment() { + assert!(!is_redirect_uri_valid( + "https://example.com/callback#fragment", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_rejects_unparseable() { + assert!(!is_redirect_uri_valid( + "not-a-url", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_query_params_must_match() { + // Registered with query params — candidate must have the same query + assert!(is_redirect_uri_valid( + "https://example.com/callback?key=value", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Different query param value must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback?key=evil", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Missing query params when registered URI has them must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback", + ["https://example.com/callback?key=value"].iter().copied(), + )); + + // Extra query params when registered URI has none must be rejected + assert!(!is_redirect_uri_valid( + "https://example.com/callback?extra=param", + ["https://example.com/callback"].iter().copied(), + )); + } + + #[test] + fn test_redirect_uri_matches_with_port() { + assert!(is_redirect_uri_valid( + "https://example.com:8443/callback", + ["https://example.com:8443/callback"].iter().copied(), + )); + + assert!(!is_redirect_uri_valid( + "https://example.com:9999/callback", + ["https://example.com:8443/callback"].iter().copied(), + )); + } +} diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index d5abcb73..2cbc8949 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -130,6 +130,20 @@ pub struct OAuthAuthzCodeRedirectHeaders { location: String, } +/// Validate that response_type is "code" per RFC 6749 §4.1.1. +fn validate_response_type(response_type: &str) -> Result<(), OAuthError> { + if response_type == "code" { + Ok(()) + } else { + Err(OAuthError { + error: OAuthErrorCode::UnsupportedResponseType, + error_description: Some("Only response_type=code is supported".to_string()), + error_uri: None, + state: None, + }) + } +} + // Lookup the client specified by the provided client id and verify that the redirect uri // is a valid for this client. If either of these fail we return an unauthorized response async fn get_oauth_client( @@ -220,6 +234,9 @@ where tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); + // Validate response_type. Only "code" is supported (RFC 6749 §4.1.1). + validate_response_type(&query.response_type)?; + // Validate the client's PKCE challenge method. Only S256 is supported. if query.code_challenge_method != "S256" { return Err(OAuthError { @@ -2024,4 +2041,29 @@ mod tests { Some(cookie::time::Duration::seconds(0)) ); } + + #[test] + fn test_valid_response_type_is_accepted() { + assert!(super::validate_response_type("code").is_ok()); + } + + #[test] + fn test_invalid_response_type_is_rejected() { + let err = super::validate_response_type("token").unwrap_err(); + assert_eq!(err.error, OAuthErrorCode::UnsupportedResponseType); + } + + #[test] + fn test_empty_response_type_is_rejected() { + assert!(super::validate_response_type("").is_err()); + } + + #[test] + fn test_response_type_rejects_similar_values() { + assert!(super::validate_response_type("Code").is_err()); + assert!(super::validate_response_type("CODE").is_err()); + assert!(super::validate_response_type("code ").is_err()); + assert!(super::validate_response_type("token").is_err()); + assert!(super::validate_response_type("code token").is_err()); + } } diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs index 064ff2d5..728a6c74 100644 --- a/v-api/src/endpoints/login/oauth/flow/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -146,6 +146,26 @@ where let device_info = device_info.unwrap(); let exchange_request = body.into_inner(); + + // Validate grant_type per RFC 8628 §3.4 + if !validate_device_grant_type(&exchange_request.grant_type) { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some( + "grant_type must be urn:ietf:params:oauth:grant-type:device_code" + .to_string(), + ), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + let exchange = AccessTokenExchange::new(exchange_request, device_info); let client = reqwest::Client::new(); @@ -253,6 +273,11 @@ where } } +/// Validate the grant_type for device code exchange per RFC 8628 §3.4. +fn validate_device_grant_type(grant_type: &str) -> bool { + grant_type == "urn:ietf:params:oauth:grant-type:device_code" +} + /// Headers that are safe to forward from an upstream OAuth provider response. /// Only `Content-Type` is needed so the client can parse the body. Polling backoff /// is handled via the JSON body per RFC 8628 (`interval` field / `slow_down` error), @@ -343,7 +368,7 @@ mod tests { }; use hyper::body::Bytes; - use super::{handle_token_parse_failure, proxy_upstream_response}; + use super::{handle_token_parse_failure, proxy_upstream_response, validate_device_grant_type}; #[test] fn test_upstream_set_cookie_is_stripped_from_error_response() { @@ -479,4 +504,32 @@ mod tests { "Upstream Set-Cookie header must not be forwarded via token parse failure path" ); } + + #[test] + fn test_valid_device_grant_type_is_accepted() { + assert!(validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:device_code" + )); + } + + #[test] + fn test_invalid_device_grant_type_is_rejected() { + assert!(!validate_device_grant_type("authorization_code")); + } + + #[test] + fn test_empty_device_grant_type_is_rejected() { + assert!(!validate_device_grant_type("")); + } + + #[test] + fn test_device_grant_type_rejects_similar_values() { + assert!(!validate_device_grant_type("device_code")); + assert!(!validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:device_Code" + )); + assert!(!validate_device_grant_type( + "urn:ietf:params:oauth:grant-type:authorization_code" + )); + } } From 6efc48b9cfdf3ad534f981db8e093c6a41e9c187 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 22:45:50 -0500 Subject: [PATCH 26/51] Enfore permission --- v-api/src/endpoints/login/oauth/flow/code.rs | 284 ++++++++++++++++++- 1 file changed, 275 insertions(+), 9 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 2cbc8949..2bdfe0b3 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -26,7 +26,7 @@ use tap::TapFallible; use tracing::instrument; use uuid::Uuid; use v_model::{ - permissions::{AsScope, PermissionStorage}, + permissions::{AsScope, PermissionStorage, Permissions}, schema_ext::LoginAttemptState, LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, }; @@ -734,15 +734,35 @@ where !query.request_idp_token, ) .await?; - let idp_token = info.idp_token.clone(); tracing::debug!("Retrieved user information from remote provider"); + complete_exchange(ctx, info, &attempt, query.request_idp_token).await +} + +async fn complete_exchange( + ctx: &VContext, + info: UserInfo, + attempt: &LoginAttempt, + request_idp_token: bool, +) -> Result, HttpError> +where + T: VAppPermission + PermissionStorage, +{ + let idp_token = info.idp_token.clone(); + // Register this user as an API user if needed let (api_user_info, api_user_provider) = ctx .register_api_user(&ctx.builtin_registration_user(), info) .await?; + // Only return the IdP token if the caller requested it AND the user has permission + let idp_token = filter_idp_token( + idp_token, + request_idp_token, + &api_user_info.user.permissions, + ); + tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); let scope = attempt @@ -768,6 +788,25 @@ where })) } +/// Filter the IdP token based on whether it was requested and whether the user has +/// the `RetrieveRemoteAccessToken` permission. Returns `None` if either condition +/// is not met. +fn filter_idp_token(idp_token: Option, requested: bool, permissions: &Permissions) -> Option +where + T: VAppPermission, +{ + if !requested { + return None; + } + + if permissions.can(&VPermission::RetrieveRemoteAccessToken.into()) { + idp_token + } else { + tracing::info!("User requested IdP token but lacks RetrieveRemoteAccessToken permission"); + None + } +} + async fn authorize_code_exchange( ctx: &VContext, provider: &dyn OAuthProvider, @@ -988,8 +1027,12 @@ mod tests { use uuid::Uuid; use v_model::{ schema_ext::LoginAttemptState, - storage::{MockLoginAttemptStore, MockOAuthClientStore}, - LoginAttempt, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, + storage::{ + MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, + MockLoginAttemptStore, MockMapperStore, MockOAuthClientStore, + }, + AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, LoginAttempt, NewApiUser, + NewApiUserProvider, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, }; use crate::{ @@ -998,12 +1041,15 @@ mod tests { test_mocks::{mock_context, MockStorage}, VContext, }, - endpoints::login::oauth::{ - flow::code::{ - authz_code_callback_op_inner, verify_csrf, verify_login_attempt, - OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, LOGIN_ATTEMPT_COOKIE, + endpoints::login::{ + oauth::{ + flow::code::{ + authz_code_callback_op_inner, verify_csrf, verify_login_attempt, + OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, LOGIN_ATTEMPT_COOKIE, + }, + OAuthProviderName, }, - OAuthProviderName, + ExternalUserId, UserInfo, }, permissions::VPermission, }; @@ -2066,4 +2112,224 @@ mod tests { assert!(super::validate_response_type("token").is_err()); assert!(super::validate_response_type("code token").is_err()); } + + #[test] + fn test_filter_idp_token_returns_token_when_requested_and_permitted() { + let permissions: v_model::permissions::Permissions = + vec![VPermission::RetrieveRemoteAccessToken].into(); + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(token, true, &permissions); + assert_eq!(result, Some("idp-token-value".to_string())); + } + + #[test] + fn test_filter_idp_token_returns_none_when_not_requested() { + let permissions: v_model::permissions::Permissions = + vec![VPermission::RetrieveRemoteAccessToken].into(); + let token = Some("idp-token-value".to_string()); + + // Even with the permission, if not requested the token is not returned + let result = super::filter_idp_token(token, false, &permissions); + assert_eq!(result, None); + } + + #[test] + fn test_filter_idp_token_returns_none_when_permission_missing() { + // User has some permissions but not RetrieveRemoteAccessToken + let permissions: v_model::permissions::Permissions = + vec![VPermission::CreateApiUser].into(); + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(token, true, &permissions); + assert_eq!(result, None); + } + + #[test] + fn test_filter_idp_token_returns_none_when_no_permissions() { + let permissions: v_model::permissions::Permissions = + Vec::::new().into(); + let token = Some("idp-token-value".to_string()); + + let result = super::filter_idp_token(token, true, &permissions); + assert_eq!(result, None); + } + + #[test] + fn test_filter_idp_token_returns_none_when_token_is_none() { + let permissions: v_model::permissions::Permissions = + vec![VPermission::RetrieveRemoteAccessToken].into(); + + // Token was None (e.g. revoked upstream) — should stay None regardless of permission + let result = super::filter_idp_token(None, true, &permissions); + assert_eq!(result, None); + } + + /// Set up mock storage for `complete_exchange` tests. The registered user will + /// have the given `user_permissions`. + fn mock_exchange_storage( + user_permissions: Vec, + ) -> MockStorage { + // ApiUserProviderStore: list returns empty (new user), upsert returns a provider + let mut provider_store = MockApiUserProviderStore::new(); + provider_store + .expect_list() + .returning(move |_, _| Ok(vec![])); + provider_store + .expect_upsert() + .returning(move |p: NewApiUserProvider| { + Ok(ApiUserProvider { + id: p.id, + user_id: p.user_id, + provider: p.provider, + provider_id: p.provider_id, + emails: p.emails, + display_names: p.display_names, + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }) + }); + + // ApiUserStore: upsert creates a user with the specified permissions + let mut user_store = MockApiUserStore::new(); + user_store + .expect_upsert() + .returning(move |u: NewApiUser| { + Ok(ApiUserInfo { + user: ApiUser { + id: u.id, + permissions: user_permissions.clone().into(), + groups: u.groups, + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }, + email: None, + providers: vec![], + }) + }); + + // MapperStore: list returns empty (no mappers configured) + let mut mapper_store = MockMapperStore::new(); + mapper_store + .expect_list() + .returning(|_, _| Ok(vec![])); + + // AccessTokenStore: upsert returns a token + let mut access_token_store = MockAccessTokenStore::new(); + access_token_store + .expect_upsert() + .returning(|token| { + Ok(AccessToken { + id: token.id, + user_id: token.user_id, + revoked_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }) + }); + + let mut storage = MockStorage::new(); + storage.api_user_provider_store = Some(Arc::new(provider_store)); + storage.api_user_store = Some(Arc::new(user_store)); + storage.mapper_store = Some(Arc::new(mapper_store)); + storage.access_token_store = Some(Arc::new(access_token_store)); + storage + } + + fn mock_user_info_with_idp_token() -> UserInfo { + UserInfo { + external_id: ExternalUserId::Google("test-google-id".to_string()), + verified_emails: vec!["user@example.com".to_string()], + display_name: Some("Test User".to_string()), + idp_token: Some("secret-upstream-token".to_string()), + } + } + + fn mock_completed_attempt() -> LoginAttempt { + LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::Complete, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://example.com/callback".to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: Some("test-code".to_string()), + expires_at: Some(Utc::now().add(TimeDelta::try_seconds(300).unwrap())), + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: Some("remote-code".to_string()), + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + } + } + + #[tokio::test] + async fn test_exchange_returns_idp_token_when_requested_and_permitted() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken, + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + + let response = super::complete_exchange(&ctx, info, &attempt, true) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, + Some("secret-upstream-token".to_string()), + "IdP token must be returned when requested and user has RetrieveRemoteAccessToken" + ); + } + + #[tokio::test] + async fn test_exchange_omits_idp_token_when_permission_missing() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + // Notably missing: VPermission::RetrieveRemoteAccessToken + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + + let response = super::complete_exchange(&ctx, info, &attempt, true) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, None, + "IdP token must NOT be returned when user lacks RetrieveRemoteAccessToken" + ); + } + + #[tokio::test] + async fn test_exchange_omits_idp_token_when_not_requested() { + let storage = mock_exchange_storage(vec![ + VPermission::CreateAccessToken, + VPermission::RetrieveRemoteAccessToken, + ]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); + let info = mock_user_info_with_idp_token(); + + let response = super::complete_exchange(&ctx, info, &attempt, false) + .await + .unwrap() + .0; + + assert_eq!( + response.idp_token, None, + "IdP token must NOT be returned when not requested, even with permission" + ); + } } From 61646b1004d3ed00b0b150e3f93e65899e453c6f Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 22:45:57 -0500 Subject: [PATCH 27/51] Fmt --- v-api/src/endpoints/login/oauth/flow/code.rs | 34 +++++++++----------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 2bdfe0b3..3aa10e7c 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -791,7 +791,11 @@ where /// Filter the IdP token based on whether it was requested and whether the user has /// the `RetrieveRemoteAccessToken` permission. Returns `None` if either condition /// is not met. -fn filter_idp_token(idp_token: Option, requested: bool, permissions: &Permissions) -> Option +fn filter_idp_token( + idp_token: Option, + requested: bool, + permissions: &Permissions, +) -> Option where T: VAppPermission, { @@ -2167,9 +2171,7 @@ mod tests { /// Set up mock storage for `complete_exchange` tests. The registered user will /// have the given `user_permissions`. - fn mock_exchange_storage( - user_permissions: Vec, - ) -> MockStorage { + fn mock_exchange_storage(user_permissions: Vec) -> MockStorage { // ApiUserProviderStore: list returns empty (new user), upsert returns a provider let mut provider_store = MockApiUserProviderStore::new(); provider_store @@ -2212,23 +2214,19 @@ mod tests { // MapperStore: list returns empty (no mappers configured) let mut mapper_store = MockMapperStore::new(); - mapper_store - .expect_list() - .returning(|_, _| Ok(vec![])); + mapper_store.expect_list().returning(|_, _| Ok(vec![])); // AccessTokenStore: upsert returns a token let mut access_token_store = MockAccessTokenStore::new(); - access_token_store - .expect_upsert() - .returning(|token| { - Ok(AccessToken { - id: token.id, - user_id: token.user_id, - revoked_at: None, - created_at: Utc::now(), - updated_at: Utc::now(), - }) - }); + access_token_store.expect_upsert().returning(|token| { + Ok(AccessToken { + id: token.id, + user_id: token.user_id, + revoked_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }) + }); let mut storage = MockStorage::new(); storage.api_user_provider_store = Some(Arc::new(provider_store)); From e7de8f140ae9c637616d07538398b01fd296ffcc Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 23:20:03 -0500 Subject: [PATCH 28/51] Merge fixes --- v-api/src/endpoints/login/magic_link/client.rs | 1 - v-api/src/endpoints/login/oauth/client.rs | 1 - v-api/src/endpoints/login/oauth/flow/code.rs | 8 +++----- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/v-api/src/endpoints/login/magic_link/client.rs b/v-api/src/endpoints/login/magic_link/client.rs index 4935c5b4..a1981689 100644 --- a/v-api/src/endpoints/login/magic_link/client.rs +++ b/v-api/src/endpoints/login/magic_link/client.rs @@ -22,7 +22,6 @@ use crate::{ permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, util::response::{bad_request, to_internal_error}, - VContext, }; #[instrument(skip(rqctx), err(Debug))] diff --git a/v-api/src/endpoints/login/oauth/client.rs b/v-api/src/endpoints/login/oauth/client.rs index 31b67fc8..22e0e0df 100644 --- a/v-api/src/endpoints/login/oauth/client.rs +++ b/v-api/src/endpoints/login/oauth/client.rs @@ -22,7 +22,6 @@ use crate::{ permissions::{VAppPermission, VPermission}, secrets::OpenApiSecretString, util::response::{bad_request, to_internal_error}, - VContext, }; #[instrument(skip(rqctx), err(Debug))] diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 87dce269..db73b7d3 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -25,8 +25,7 @@ use tap::TapFallible; use tracing::instrument; use uuid::Uuid; use v_model::{ - permissions::{AsScope, PermissionStorage, Permissions}, - schema_ext::LoginAttemptState, + LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, permissions::{AsScope, PermissionStorage, Permissions}, schema_ext::LoginAttemptState }; use super::super::{OAuthProvider, OAuthProviderNameParam}; @@ -37,7 +36,6 @@ use crate::{ endpoints::login::{ oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, LoginError, UserInfo, - oauth::{CheckOAuthClient, ClientType}, }, error::ApiError, permissions::{VAppPermission, VPermission}, @@ -1035,8 +1033,8 @@ mod tests { MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, MockLoginAttemptStore, MockMapperStore, MockOAuthClientStore, }, - AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, LoginAttempt, NewApiUser, - NewApiUserProvider, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, + AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, NewApiUser, + NewApiUserProvider, }; use crate::{ From 1733c8d61b4c4aee0d4f6d2d48909477a63b74d3 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 23:20:17 -0500 Subject: [PATCH 29/51] Fmt --- v-api/src/context/mod.rs | 6 ++--- v-api/src/endpoints/login/oauth/flow/code.rs | 25 +++++++++++-------- .../login/oauth/flow/device_token.rs | 8 +++--- v-api/src/endpoints/login/oauth/mod.rs | 6 ++--- .../endpoints/login/oauth/remote/github.rs | 2 +- .../endpoints/login/oauth/remote/google.rs | 2 +- .../endpoints/login/oauth/remote/zendesk.rs | 2 +- v-cli-sdk/src/cmd/auth/login.rs | 2 +- v-cli-sdk/src/cmd/auth/mod.rs | 2 +- v-cli-sdk/src/cmd/auth/oauth/code.rs | 4 +-- v-cli-sdk/src/cmd/auth/oauth/device.rs | 2 +- v-cli-sdk/src/err.rs | 2 +- v-cli-sdk/src/lib.rs | 4 +-- 13 files changed, 36 insertions(+), 31 deletions(-) diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index e2eea0f5..db20eea3 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1263,14 +1263,14 @@ pub(crate) mod test_mocks { }; use crate::{ + VContextBuilder, config::{ JwtConfig, ResolvedOAuthConfig, ResolvedOAuthWebConfig, ResolvedOAuthWebProxyConfig, }, endpoints::login::oauth::{ - remote::google::GoogleOAuthProvider, remote::zendesk::ZendeskOAuthProvider, - OAuthProviderName, + OAuthProviderName, remote::google::GoogleOAuthProvider, + remote::zendesk::ZendeskOAuthProvider, }, - VContextBuilder, mapper::DefaultMappingEngine, permissions::VPermission, util::tests::{MockKey, mock_key}, diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index db73b7d3..a3bcaebf 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -10,7 +10,7 @@ use dropshot::{ RequestContext, RequestInfo, SharedExtractor, TypedBody, http_response_temporary_redirect, }; use dropshot_authorization_header::basic::BasicAuth; -use http::{header::SET_COOKIE, HeaderValue}; +use http::{HeaderValue, header::SET_COOKIE}; use newtype_uuid::{GenericUuid, TypedUuid}; use oauth2::{ AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, @@ -25,7 +25,9 @@ use tap::TapFallible; use tracing::instrument; use uuid::Uuid; use v_model::{ - LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, permissions::{AsScope, PermissionStorage, Permissions}, schema_ext::LoginAttemptState + LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, + permissions::{AsScope, PermissionStorage, Permissions}, + schema_ext::LoginAttemptState, }; use super::super::{OAuthProvider, OAuthProviderNameParam}; @@ -34,8 +36,8 @@ use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, endpoints::login::{ - oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, LoginError, UserInfo, + oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, }, error::ApiError, permissions::{VAppPermission, VPermission}, @@ -1027,14 +1029,13 @@ mod tests { use secrecy::SecretString; use uuid::Uuid; use v_model::{ - LoginAttempt, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, + AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, LoginAttempt, NewApiUser, + NewApiUserProvider, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, schema_ext::LoginAttemptState, storage::{ MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, MockLoginAttemptStore, MockMapperStore, MockOAuthClientStore, }, - AccessToken, ApiUser, ApiUserInfo, ApiUserProvider, NewApiUser, - NewApiUserProvider, }; use crate::{ @@ -1044,14 +1045,14 @@ mod tests { test_mocks::{MockStorage, mock_context}, }, endpoints::login::{ + ExternalUserId, UserInfo, oauth::{ + OAuthProviderName, flow::code::{ + LOGIN_ATTEMPT_COOKIE, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, authz_code_callback_op_inner, verify_csrf, verify_login_attempt, - OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, LOGIN_ATTEMPT_COOKIE, }, - OAuthProviderName, }, - ExternalUserId, UserInfo, }, permissions::VPermission, }; @@ -1174,7 +1175,11 @@ mod tests { .unwrap(); let headers = response.headers(); - let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); + let expected_location = format!( + "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", + attempt.id, + challenge.as_str() + ); assert_eq!( expected_location, diff --git a/v-api/src/endpoints/login/oauth/flow/device_token.rs b/v-api/src/endpoints/login/oauth/flow/device_token.rs index 728a6c74..b2521770 100644 --- a/v-api/src/endpoints/login/oauth/flow/device_token.rs +++ b/v-api/src/endpoints/login/oauth/flow/device_token.rs @@ -4,9 +4,9 @@ use chrono::{DateTime, Utc}; use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; +use http::{HeaderMap, HeaderValue, Response, StatusCode, header}; use hyper::body::Bytes; -use oauth2::{basic::BasicTokenType, EmptyExtraTokenFields, StandardTokenResponse, TokenResponse}; +use oauth2::{EmptyExtraTokenFields, StandardTokenResponse, TokenResponse, basic::BasicTokenType}; use schemars::JsonSchema; use secrecy::ExposeSecret; use serde::{Deserialize, Serialize}; @@ -19,7 +19,7 @@ use super::super::OAuthProviderNameParam; use crate::endpoints::login::UserInfoProvider; use crate::{ context::ApiContext, - endpoints::login::{oauth::OAuthProviderDeviceInfo, LoginError}, + endpoints::login::{LoginError, oauth::OAuthProviderDeviceInfo}, error::ApiError, permissions::VAppPermission, response::internal_error, @@ -363,8 +363,8 @@ fn handle_token_parse_failure( #[cfg(test)] mod tests { use http::{ - header::{self, HeaderName, SET_COOKIE}, HeaderMap, HeaderValue, StatusCode, + header::{self, HeaderName, SET_COOKIE}, }; use hyper::body::Bytes; diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 4056aaab..6a3ebbbc 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use http::Method; -use hyper::{body::Bytes, header::HeaderValue, header::AUTHORIZATION}; +use hyper::{body::Bytes, header::AUTHORIZATION, header::HeaderValue}; use newtype_uuid::TypedUuid; use oauth2::{ AuthUrl, ClientId, ClientSecret, EndpointMaybeSet, EndpointNotSet, EndpointSet, RedirectUrl, @@ -20,11 +20,11 @@ use tracing::instrument; use v_model::{OAuthClient, OAuthClientId}; use crate::{ - authn::{key::RawKey, Verify}, + authn::{Verify, key::RawKey}, secrets::OpenApiSecretString, }; -use super::{is_redirect_uri_valid, UserInfo, UserInfoError, UserInfoProvider}; +use super::{UserInfo, UserInfoError, UserInfoProvider, is_redirect_uri_valid}; pub mod client; pub mod flow; diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index 3c0d9467..4bae1b55 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -11,11 +11,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, oauth::{ OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, }, - ExternalUserId, UserInfo, UserInfoError, }, }; diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs index 23879e89..2a24e956 100644 --- a/v-api/src/endpoints/login/oauth/remote/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -10,11 +10,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, oauth::{ OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, }, - ExternalUserId, UserInfo, UserInfoError, }, }; diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index 51ecd31a..211b8707 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -10,11 +10,11 @@ use std::fmt; use crate::{ config::ResolvedOAuthConfig, endpoints::login::{ + ExternalUserId, UserInfo, UserInfoError, oauth::{ OAuthProviderAuthorizationCodeInfo, OAuthProviderAuthorizationCodePkceInfo, OAuthProviderAuthorizationCodeRemoteInfo, OAuthProviderDeviceInfo, }, - ExternalUserId, UserInfo, UserInfoError, }, }; diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 3e21691a..55652adb 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -8,8 +8,8 @@ use oauth2::TokenResponse; use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; use crate::{ - cmd::auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, VCliConfig, VCliContext, + cmd::auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, }; pub trait CliAdapterToken { diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs index ad7d6a24..d4b8bcaa 100644 --- a/v-cli-sdk/src/cmd/auth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -6,7 +6,7 @@ use anyhow::Result; use clap::{Parser, Subcommand}; use std::error::Error as StdError; -use crate::{cmd::auth::login::CliConsumerLoginProvider, VCliContext}; +use crate::{VCliContext, cmd::auth::login::CliConsumerLoginProvider}; pub mod login; pub mod oauth; diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 0177d874..5bbfce0b 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -10,8 +10,8 @@ use http_body_util::Full; use hyper::body::{Bytes, Incoming}; use oauth2::{ - basic::BasicClient, AuthType, AuthUrl, ClientId, CsrfToken, EndpointNotSet, EndpointSet, - PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, + AuthType, AuthUrl, ClientId, CsrfToken, EndpointNotSet, EndpointSet, PkceCodeChallenge, + RedirectUrl, Scope, TokenUrl, basic::BasicClient, }; use tokio::sync::oneshot; use uuid::Uuid; diff --git a/v-cli-sdk/src/cmd/auth/oauth/device.rs b/v-cli-sdk/src/cmd/auth/oauth/device.rs index 4acf227b..9ca43df4 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/device.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/device.rs @@ -1,8 +1,8 @@ use anyhow::Result; use oauth2::{ - basic::{BasicClient, BasicTokenType}, AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, EndpointSet, Scope, StandardDeviceAuthorizationResponse, StandardTokenResponse, TokenUrl, + basic::{BasicClient, BasicTokenType}, }; use crate::cmd::auth::oauth::CliOAuthProviderInfo; diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs index 6f8066a6..0d54b489 100644 --- a/v-cli-sdk/src/err.rs +++ b/v-cli-sdk/src/err.rs @@ -2,7 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -use anyhow::{anyhow, Error}; +use anyhow::{Error, anyhow}; use progenitor_client::Error as ProgenitorClientError; use crate::{VApiErrorMessage, VCliContext, VerbosityLevel}; diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs index 27f678a5..bd398174 100644 --- a/v-cli-sdk/src/lib.rs +++ b/v-cli-sdk/src/lib.rs @@ -57,8 +57,8 @@ pub trait VCliContext { LongToken = Self::LongToken, Error = Self::Error, > + Send - + Sync - + 'static; + + Sync + + 'static; fn mlink_adapter( &self, ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; From a539ffe05dd50a99a18b9700a7f72331336063ae Mon Sep 17 00:00:00 2001 From: augustuswm Date: Wed, 6 May 2026 23:21:56 -0500 Subject: [PATCH 30/51] Remove extraneous dep --- Cargo.lock | 1 - v-cli-sdk/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f5bbf1ba..521070ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3649,7 +3649,6 @@ name = "v-cli-sdk" version = "0.2.0" dependencies = [ "anyhow", - "chrono", "clap", "http", "http-body-util", diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml index 030df94a..962e0217 100644 --- a/v-cli-sdk/Cargo.toml +++ b/v-cli-sdk/Cargo.toml @@ -5,7 +5,6 @@ edition = "2021" [dependencies] anyhow = { workspace = true } -chrono = { workspace = true } clap = { workspace = true } http = { workspace = true } http-body-util = { workspace = true } From 0f9b4b3ff97c6d36842e92b70f46983818bdba1c Mon Sep 17 00:00:00 2001 From: Augustus Mayo Date: Wed, 6 May 2026 23:22:41 -0500 Subject: [PATCH 31/51] Update v-cli-sdk/src/cmd/auth/oauth/code.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- v-cli-sdk/src/cmd/auth/oauth/code.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 5bbfce0b..39c17c89 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -1,3 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + use std::{ future::Future, pin::Pin, From 29f2c7a888ad29a744bf913f3430e60641c6ef07 Mon Sep 17 00:00:00 2001 From: Augustus Mayo Date: Wed, 6 May 2026 23:22:50 -0500 Subject: [PATCH 32/51] Update v-api/src/endpoints/login/oauth/remote/mod.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- v-api/src/endpoints/login/oauth/remote/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v-api/src/endpoints/login/oauth/remote/mod.rs b/v-api/src/endpoints/login/oauth/remote/mod.rs index db2a1dd9..3a924871 100644 --- a/v-api/src/endpoints/login/oauth/remote/mod.rs +++ b/v-api/src/endpoints/login/oauth/remote/mod.rs @@ -1,3 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + pub mod github; pub mod google; pub mod zendesk; From e984b911db28ca0e40b6826a4074b87af26b0ab4 Mon Sep 17 00:00:00 2001 From: Augustus Mayo Date: Wed, 6 May 2026 23:22:59 -0500 Subject: [PATCH 33/51] Update v-cli-sdk/src/cmd/auth/oauth/device.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- v-cli-sdk/src/cmd/auth/oauth/device.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v-cli-sdk/src/cmd/auth/oauth/device.rs b/v-cli-sdk/src/cmd/auth/oauth/device.rs index 9ca43df4..b1956dbb 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/device.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/device.rs @@ -1,3 +1,7 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + use anyhow::Result; use oauth2::{ AuthType, AuthUrl, ClientId, DeviceAuthorizationUrl, EmptyExtraTokenFields, EndpointNotSet, From 16257bb9053e4b3b1c84293ab908c3610b76320b Mon Sep 17 00:00:00 2001 From: Augustus Mayo Date: Wed, 6 May 2026 23:23:06 -0500 Subject: [PATCH 34/51] Update v-api/src/endpoints/login/oauth/flow/mod.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- v-api/src/endpoints/login/oauth/flow/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v-api/src/endpoints/login/oauth/flow/mod.rs b/v-api/src/endpoints/login/oauth/flow/mod.rs index 304abf26..305cd9ab 100644 --- a/v-api/src/endpoints/login/oauth/flow/mod.rs +++ b/v-api/src/endpoints/login/oauth/flow/mod.rs @@ -1,2 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + pub mod code; pub mod device_token; From a1ce9ed10dd0f8758294457c5a2aaec421375284 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 08:12:55 -0500 Subject: [PATCH 35/51] Fix local dev endpoints --- v-api/src/endpoints/login/local/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/v-api/src/endpoints/login/local/mod.rs b/v-api/src/endpoints/login/local/mod.rs index ca077496..636f214c 100644 --- a/v-api/src/endpoints/login/local/mod.rs +++ b/v-api/src/endpoints/login/local/mod.rs @@ -13,7 +13,7 @@ use v_model::permissions::PermissionStorage; use crate::{ authn::jwt::Claims, context::ApiContext, - endpoints::login::{ExternalUserId, UserInfo, oauth::device_token::ProxyTokenResponse}, + endpoints::login::{ExternalUserId, UserInfo, oauth::flow::device_token::ProxyTokenResponse}, permissions::{VAppPermission, VPermission}, }; @@ -38,6 +38,7 @@ where external_id: ExternalUserId::Local(body.external_id), verified_emails: vec![body.email], display_name: Some("Local Dev".to_string()), + idp_token: None, }; let (api_user, api_user_provider) = ctx From 2c22bc069ce92029d8150f7a245216b675bb3a4c Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 08:24:37 -0500 Subject: [PATCH 36/51] Pass down provider --- v-cli-sdk/src/cmd/auth/login.rs | 1 + v-cli-sdk/src/cmd/auth/oauth/code.rs | 8 +++++--- v-cli-sdk/src/cmd/auth/oauth/mod.rs | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs index 55652adb..e641062f 100644 --- a/v-cli-sdk/src/cmd/auth/login.rs +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -90,6 +90,7 @@ where }, } +#[derive(Copy, Clone)] pub enum LoginProvider { Google, GitHub, diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 39c17c89..6ad44012 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -21,8 +21,7 @@ use tokio::sync::oneshot; use uuid::Uuid; use crate::cmd::auth::{ - oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, - proxy::run_proxy_server, + login::LoginProvider, oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, proxy::run_proxy_server }; type CodeClient = BasicClient< @@ -39,6 +38,7 @@ type CodeClient = BasicClient< >; pub struct CodeOAuth { + provider: LoginProvider, client: CodeClient, client_id: Uuid, redirect_uri: String, @@ -72,6 +72,7 @@ impl CodeOAuth { )?); Ok(Self { + provider: provider.provider(), client, client_id: provider.client_id(), redirect_uri: provider.redirect_endpoint().unwrap_or_default().to_string(), @@ -136,6 +137,7 @@ impl CodeOAuth { let error_token_tx = Arc::clone(&token_tx); let client_id = self.client_id; let redirect_uri = self.redirect_uri.clone(); + let provider = self.provider; async move { let callback: crate::cmd::auth::proxy::Callback = Arc::new(Mutex::new(Some( @@ -162,7 +164,7 @@ impl CodeOAuth { // Forward the redirect request to the API server. let token = adapter .exchange_authorization_code(super::AuthorizationCodeExchange { - provider: crate::cmd::auth::login::LoginProvider::Zendesk, + provider, client_id, redirect_uri: redirect_uri.clone(), grant_type: "authorization_code".to_string(), diff --git a/v-cli-sdk/src/cmd/auth/oauth/mod.rs b/v-cli-sdk/src/cmd/auth/oauth/mod.rs index 8a65605a..4b0a0169 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/mod.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/mod.rs @@ -10,7 +10,7 @@ use uuid::Uuid; pub mod code; pub mod device; -use crate::cmd::auth::login::CliAdapterToken; +use crate::cmd::auth::login::{CliAdapterToken, LoginProvider}; /// Parameters for exchanging an authorization code for an access token. pub struct AuthorizationCodeExchange { @@ -46,6 +46,7 @@ pub trait CliOAuthAdapter { } pub trait CliOAuthProviderInfo { + fn provider(&self) -> LoginProvider; fn client_id(&self) -> Uuid; fn remote_client_id(&self) -> &str; fn public_pkce_port(&self) -> Option; From 450ade65c1d9f6f73b1d7225040447140d8f47e6 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 08:32:24 -0500 Subject: [PATCH 37/51] Fmt --- v-cli-sdk/src/cmd/auth/oauth/code.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 6ad44012..8a471b88 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -21,7 +21,9 @@ use tokio::sync::oneshot; use uuid::Uuid; use crate::cmd::auth::{ - login::LoginProvider, oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, proxy::run_proxy_server + login::LoginProvider, + oauth::{CliOAuthAdapter, CliOAuthProviderInfo}, + proxy::run_proxy_server, }; type CodeClient = BasicClient< From 1a7b798a387b1cecfb5f239ac0e3e16a0451ea12 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 09:16:36 -0500 Subject: [PATCH 38/51] Permissions need to be resolved during login to determine idp token access --- v-api/src/context/magic_link.rs | 2 +- v-api/src/context/mod.rs | 8 +- v-api/src/endpoints/api_user.rs | 16 +-- v-api/src/endpoints/login/local/mod.rs | 2 +- v-api/src/endpoints/login/oauth/flow/code.rs | 136 +++++++++++++------ v-api/src/endpoints/login/oauth/mod.rs | 16 ++- v-model/tests/postgres.rs | 30 ++-- 7 files changed, 129 insertions(+), 81 deletions(-) diff --git a/v-api/src/context/magic_link.rs b/v-api/src/context/magic_link.rs index 0d6b4c36..5c27d5c2 100644 --- a/v-api/src/context/magic_link.rs +++ b/v-api/src/context/magic_link.rs @@ -731,7 +731,7 @@ mod tests { .expect_transition() .returning(move |id, signature, from, _to| { if &attempt_transition.id == id - && &attempt_transition.nonce_signature == signature + && attempt_transition.nonce_signature == signature && attempt_transition.attempt_state == from { Ok(Some(MagicLinkAttempt { diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index db20eea3..43321639 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1005,7 +1005,7 @@ mod tests { let provider = ApiUserProvider { id: TypedUuid::new_v4(), - user_id: user_id, + user_id, provider: "test".to_string(), provider_id: "test_id".to_string(), emails: vec![], @@ -1027,9 +1027,7 @@ mod tests { .await .unwrap(); - let jwt = AuthToken::Jwt(Jwt::new(&ctx, &user_token).await.unwrap()); - - jwt + AuthToken::Jwt(Jwt::new(ctx, &user_token).await.unwrap()) } #[tokio::test] @@ -1091,7 +1089,7 @@ mod tests { .returning(move |_, _| { Ok(Some(AccessToken { id: valid_token_id, - user_id: user_id, + user_id, revoked_at: None, created_at: Utc::now(), updated_at: Utc::now(), diff --git a/v-api/src/endpoints/api_user.rs b/v-api/src/endpoints/api_user.rs index 0c566c0b..55244e4f 100644 --- a/v-api/src/endpoints/api_user.rs +++ b/v-api/src/endpoints/api_user.rs @@ -820,9 +820,7 @@ mod tests { let mut store = MockApiUserStore::new(); store .expect_upsert() - .withf(|x: &NewApiUser| { - x.permissions.can(&VPermission::CreateApiUser.into()) - }) + .withf(|x: &NewApiUser| x.permissions.can(&VPermission::CreateApiUser)) .returning(|user| { Ok(ApiUserInfo { user: ApiUser { @@ -839,9 +837,7 @@ mod tests { }); store .expect_upsert() - .withf(|x: &NewApiUser| { - x.permissions.can(&VPermission::GetApiUsersAll.into()) - }) + .withf(|x: &NewApiUser| x.permissions.can(&VPermission::GetApiUsersAll)) .returning(|_| Err(StoreError::Unknown)); let mut api_user_provider_store = MockApiUserProviderStore::new(); api_user_provider_store @@ -921,7 +917,7 @@ mod tests { let mut store = MockApiUserStore::new(); store .expect_upsert() - .withf(move |x: &NewApiUser| &x.id == &success_id) + .withf(move |x: &NewApiUser| x.id == success_id) .returning(|user| { Ok(ApiUserInfo { user: ApiUser { @@ -938,7 +934,7 @@ mod tests { }); store .expect_upsert() - .withf(move |x: &NewApiUser| &x.id == &failure_id) + .withf(move |x: &NewApiUser| x.id == failure_id) .returning(|_| Err(StoreError::Unknown)); let mut api_user_provider_store = MockApiUserProviderStore::new(); api_user_provider_store @@ -1372,7 +1368,7 @@ mod tests { let mut token_store = MockApiKeyStore::new(); token_store .expect_get() - .with(eq(api_user_token_path.api_key_id.clone()), eq(false)) + .with(eq(api_user_token_path.api_key_id), eq(false)) .returning(move |_, _| Ok(Some(token.clone()))); token_store .expect_get() @@ -1677,7 +1673,7 @@ mod tests { let mut email_store = MockApiUserContactEmailStore::new(); email_store .expect_upsert() - .withf(move |arg| arg.user_id == user.id && arg.email == "user@company".to_string()) + .withf(move |arg| arg.user_id == user.id && arg.email == "user@company") .returning(|new| { Ok(ApiUserContactEmail { id: new.id, diff --git a/v-api/src/endpoints/login/local/mod.rs b/v-api/src/endpoints/login/local/mod.rs index 636f214c..92f0e307 100644 --- a/v-api/src/endpoints/login/local/mod.rs +++ b/v-api/src/endpoints/login/local/mod.rs @@ -56,7 +56,7 @@ where ctx.v_ctx().jwt_signer(), &api_user.user.id, &Claims::new( - &ctx.v_ctx(), + ctx.v_ctx(), None, &api_user.user.id, &api_user_provider.id, diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index a3bcaebf..3ec7902b 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -26,7 +26,7 @@ use tracing::instrument; use uuid::Uuid; use v_model::{ LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, - permissions::{AsScope, PermissionStorage, Permissions}, + permissions::{AsScope, PermissionStorage}, schema_ext::LoginAttemptState, }; @@ -755,12 +755,10 @@ where .register_api_user(&ctx.builtin_registration_user(), info) .await?; - // Only return the IdP token if the caller requested it AND the user has permission - let idp_token = filter_idp_token( - idp_token, - request_idp_token, - &api_user_info.user.permissions, - ); + // Only return the IdP token if the caller requested it AND the user has permission. + // We must resolve the full caller (including group permissions) rather than checking + // only the directly assigned user permissions. + let idp_token = filter_idp_token(ctx, idp_token, request_idp_token, &api_user_info).await; tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); @@ -788,21 +786,42 @@ where } /// Filter the IdP token based on whether it was requested and whether the user has -/// the `RetrieveRemoteAccessToken` permission. Returns `None` if either condition -/// is not met. -fn filter_idp_token( +/// the `RetrieveRemoteAccessToken` permission (including permissions inherited from +/// groups). Returns `None` if either condition is not met. +async fn filter_idp_token( + ctx: &VContext, idp_token: Option, requested: bool, - permissions: &Permissions, + api_user_info: &v_model::ApiUserInfo, ) -> Option where - T: VAppPermission, + T: VAppPermission + PermissionStorage, { if !requested { return None; } - if permissions.can(&VPermission::RetrieveRemoteAccessToken.into()) { + // Resolve the caller so that group-inherited permissions are included in the + // permission check, not just directly-assigned user permissions. + let caller = match ctx + .user + .resolve_caller(api_user_info, crate::context::BasePermissions::Full) + .await + { + Ok(caller) => caller, + Err(err) => { + tracing::warn!( + ?err, + "Failed to resolve caller permissions for IdP token check" + ); + return None; + } + }; + + if caller + .permissions + .can(&VPermission::RetrieveRemoteAccessToken.into()) + { idp_token } else { tracing::info!("User requested IdP token but lacks RetrieveRemoteAccessToken permission"); @@ -1033,7 +1052,7 @@ mod tests { NewApiUserProvider, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, schema_ext::LoginAttemptState, storage::{ - MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, + MockAccessGroupStore, MockAccessTokenStore, MockApiUserProviderStore, MockApiUserStore, MockLoginAttemptStore, MockMapperStore, MockOAuthClientStore, }, }; @@ -1063,7 +1082,7 @@ mod tests { let ctx = mock_context(Arc::new(MockStorage::new())).await; let client_id = TypedUuid::new_v4(); let key = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(&*ctx.signer()) + .sign(ctx.signer()) .await .unwrap(); let secret_signature = key.signature().to_string(); @@ -1162,7 +1181,7 @@ mod tests { }; let response = oauth_redirect_response( - &ctx.public_url(), + ctx.public_url(), &*ctx .get_oauth_provider(&OAuthProviderName::Google) .await @@ -1722,7 +1741,7 @@ mod tests { .unwrap(); let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(&*ctx.signer()) + .sign(ctx.signer()) .await .unwrap() .signature() @@ -2120,55 +2139,81 @@ mod tests { assert!(super::validate_response_type("code token").is_err()); } - #[test] - fn test_filter_idp_token_returns_token_when_requested_and_permitted() { - let permissions: v_model::permissions::Permissions = - vec![VPermission::RetrieveRemoteAccessToken].into(); + /// Create a mock context and ApiUserInfo for `filter_idp_token` tests. + async fn mock_filter_idp_token_ctx( + user_permissions: Vec, + ) -> (VContext, ApiUserInfo) { + let mut access_group_store = MockAccessGroupStore::new(); + access_group_store + .expect_list() + .returning(|_, _| Ok(vec![])); + + let mut storage = MockStorage::new(); + storage.access_group_store = Some(Arc::new(access_group_store)); + + let ctx = mock_context(Arc::new(storage)).await; + let info = ApiUserInfo { + user: ApiUser { + id: TypedUuid::new_v4(), + permissions: user_permissions.into(), + groups: Default::default(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }, + email: None, + providers: vec![], + }; + (ctx, info) + } + + #[tokio::test] + async fn test_filter_idp_token_returns_token_when_requested_and_permitted() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; let token = Some("idp-token-value".to_string()); - let result = super::filter_idp_token(token, true, &permissions); + let result = super::filter_idp_token(&ctx, token, true, &info).await; assert_eq!(result, Some("idp-token-value".to_string())); } - #[test] - fn test_filter_idp_token_returns_none_when_not_requested() { - let permissions: v_model::permissions::Permissions = - vec![VPermission::RetrieveRemoteAccessToken].into(); + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_not_requested() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; let token = Some("idp-token-value".to_string()); // Even with the permission, if not requested the token is not returned - let result = super::filter_idp_token(token, false, &permissions); + let result = super::filter_idp_token(&ctx, token, false, &info).await; assert_eq!(result, None); } - #[test] - fn test_filter_idp_token_returns_none_when_permission_missing() { + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_permission_missing() { // User has some permissions but not RetrieveRemoteAccessToken - let permissions: v_model::permissions::Permissions = - vec![VPermission::CreateApiUser].into(); + let (ctx, info) = mock_filter_idp_token_ctx(vec![VPermission::CreateApiUser]).await; let token = Some("idp-token-value".to_string()); - let result = super::filter_idp_token(token, true, &permissions); + let result = super::filter_idp_token(&ctx, token, true, &info).await; assert_eq!(result, None); } - #[test] - fn test_filter_idp_token_returns_none_when_no_permissions() { - let permissions: v_model::permissions::Permissions = - Vec::::new().into(); + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_no_permissions() { + let (ctx, info) = mock_filter_idp_token_ctx(vec![]).await; let token = Some("idp-token-value".to_string()); - let result = super::filter_idp_token(token, true, &permissions); + let result = super::filter_idp_token(&ctx, token, true, &info).await; assert_eq!(result, None); } - #[test] - fn test_filter_idp_token_returns_none_when_token_is_none() { - let permissions: v_model::permissions::Permissions = - vec![VPermission::RetrieveRemoteAccessToken].into(); + #[tokio::test] + async fn test_filter_idp_token_returns_none_when_token_is_none() { + let (ctx, info) = + mock_filter_idp_token_ctx(vec![VPermission::RetrieveRemoteAccessToken]).await; // Token was None (e.g. revoked upstream) — should stay None regardless of permission - let result = super::filter_idp_token(None, true, &permissions); + let result = super::filter_idp_token(&ctx, None, true, &info).await; assert_eq!(result, None); } @@ -2231,11 +2276,18 @@ mod tests { }) }); + // AccessGroupStore: list returns empty (no groups configured) + let mut access_group_store = MockAccessGroupStore::new(); + access_group_store + .expect_list() + .returning(|_, _| Ok(vec![])); + let mut storage = MockStorage::new(); storage.api_user_provider_store = Some(Arc::new(provider_store)); storage.api_user_store = Some(Arc::new(user_store)); storage.mapper_store = Some(Arc::new(mapper_store)); storage.access_token_store = Some(Arc::new(access_token_store)); + storage.access_group_store = Some(Arc::new(access_group_store)); storage } diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index 6a3ebbbc..c5f8d18c 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -164,7 +164,7 @@ where } } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, @@ -173,7 +173,7 @@ pub struct OAuthProviderInfo { device: Option, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderAuthorizationCodeInfo { auth_url_endpoint: String, redirect_endpoint: String, @@ -182,10 +182,11 @@ pub struct OAuthProviderAuthorizationCodeInfo { remote: OAuthProviderAuthorizationCodeRemoteInfo, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderAuthorizationCodeRemoteInfo { client_id: String, - #[serde(skip_serializing)] + #[schemars(skip)] + #[serde(skip)] client_secret: OpenApiSecretString, auth_url_endpoint: String, token_endpoint_content_type: String, @@ -193,7 +194,7 @@ pub struct OAuthProviderAuthorizationCodeRemoteInfo { revocation_endpoint: Option, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderAuthorizationCodePkceInfo { client_id: TypedUuid, redirect_endpoint: String, @@ -201,11 +202,12 @@ pub struct OAuthProviderAuthorizationCodePkceInfo { web: OAuthProviderAuthorizationCodeInfo, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct OAuthProviderDeviceInfo { client_id: TypedUuid, remote_client_id: String, - #[serde(skip_serializing)] + #[schemars(skip)] + #[serde(skip)] remote_client_secret: OpenApiSecretString, device_code_endpoint: String, token_endpoint_content_type: String, diff --git a/v-model/tests/postgres.rs b/v-model/tests/postgres.rs index e133f501..c24d6d7f 100644 --- a/v-model/tests/postgres.rs +++ b/v-model/tests/postgres.rs @@ -71,7 +71,7 @@ impl TestDb { }; println!("Creating database {}", db.db_name); - let create_result = sql_query(&format!("CREATE DATABASE {}", db.db_name)) + let create_result = sql_query(format!("CREATE DATABASE {}", db.db_name)) .execute(&mut db.conn()) .unwrap(); println!("Created database {:?}", create_result); @@ -100,7 +100,7 @@ impl TestDb { impl Drop for TestDb { fn drop(&mut self) { if self.should_drop { - sql_query(&format!("DROP DATABASE {}", self.db_name)) + sql_query(format!("DROP DATABASE {}", self.db_name)) .execute(&mut self.conn()) .unwrap(); } @@ -133,7 +133,7 @@ async fn test_api_user() { &store, NewApiUser { id: api_user_id, - permissions: vec![TestPermission::CreateApiKey(api_user_id).into()].into(), + permissions: vec![TestPermission::CreateApiKey(api_user_id)].into(), groups: BTreeSet::new(), }, ) @@ -153,7 +153,7 @@ async fn test_api_user() { &store, NewApiUser { id: api_user_id, - permissions: vec![TestPermission::CreateApiKey(api_user_id).into()].into(), + permissions: vec![TestPermission::CreateApiKey(api_user_id)].into(), groups: BTreeSet::new(), }, ) @@ -168,9 +168,9 @@ async fn test_api_user() { NewApiUser { id: api_user_id, permissions: vec![ - TestPermission::CreateApiKey(api_user_id).into(), - TestPermission::GetApiKey(api_user_id).into(), - TestPermission::DeleteApiKey(api_user_id).into(), + TestPermission::CreateApiKey(api_user_id), + TestPermission::GetApiKey(api_user_id), + TestPermission::DeleteApiKey(api_user_id), ] .into(), groups: BTreeSet::new(), @@ -183,13 +183,13 @@ async fn test_api_user() { api_user .user .permissions - .can(&TestPermission::GetApiKey(api_user_id).into()) + .can(&TestPermission::GetApiKey(api_user_id)) ); assert!( api_user .user .permissions - .can(&TestPermission::DeleteApiKey(api_user_id).into()) + .can(&TestPermission::DeleteApiKey(api_user_id)) ); // 5. Create an API token for the user @@ -199,7 +199,7 @@ async fn test_api_user() { id: TypedUuid::new_v4(), user_id: api_user.user.id, key_signature: format!("key-{}", Uuid::new_v4()), - permissions: Some(vec![TestPermission::GetApiKey(api_user_id).into()].into()), + permissions: Some(vec![TestPermission::GetApiKey(api_user_id)].into()), expires_at: Utc::now() + TimeDelta::try_seconds(5 * 60).unwrap(), }, ) @@ -215,8 +215,8 @@ async fn test_api_user() { key_signature: format!("key-{}", Uuid::new_v4()), permissions: Some( vec![ - TestPermission::CreateApiUser.into(), - TestPermission::GetApiKey(api_user_id).into(), + TestPermission::CreateApiUser, + TestPermission::GetApiKey(api_user_id), ] .into(), ), @@ -231,7 +231,7 @@ async fn test_api_user() { .permissions .as_ref() .unwrap() - .can(&TestPermission::CreateApiUser.into()) + .can(&TestPermission::CreateApiUser) ); // 7. Create an API token with excess permissions for the user @@ -243,8 +243,8 @@ async fn test_api_user() { key_signature: format!("key-{}", Uuid::new_v4()), permissions: Some( vec![ - TestPermission::CreateApiUser.into(), - TestPermission::GetApiKey(api_user_id).into(), + TestPermission::CreateApiUser, + TestPermission::GetApiKey(api_user_id), ] .into(), ), From 0e8bb73eae0ac447b6e0c092587fd1a4ef781ede Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 12:42:00 -0500 Subject: [PATCH 39/51] Add support for Zendesk expires_in extension --- v-api/src/endpoints/login/oauth/flow/code.rs | 4 ++++ v-api/src/endpoints/login/oauth/mod.rs | 1 + v-api/src/endpoints/login/oauth/remote/github.rs | 4 ++++ v-api/src/endpoints/login/oauth/remote/google.rs | 4 ++++ v-api/src/endpoints/login/oauth/remote/zendesk.rs | 4 ++++ 5 files changed, 17 insertions(+) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 3ec7902b..6b55bcdf 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -996,6 +996,10 @@ async fn fetch_user_info( request = request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string())) } + if let Some(expires_in) = provider.expires_in() { + request = request.add_extra_param("expires_in", expires_in.to_string()); + } + let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); let response = request .request_async(&oauth_client) diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index c5f8d18c..a592373f 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -72,6 +72,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { fn authz_code_pkce_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodePkceInfo>; fn device_code_flow_info(&self) -> Option<&OAuthProviderDeviceInfo>; + fn expires_in(&self) -> Option; fn default_scopes(&self) -> &[String]; /// Whether the remote OAuth provider supports PKCE (RFC 7636). Providers must diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index 4bae1b55..ec814495 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -141,6 +141,10 @@ impl OAuthProvider for GitHubOAuthProvider { "https://api.github.com/user/emails", ] } + + fn expires_in(&self) -> Option { + None + } fn default_scopes(&self) -> &[String] { &self.default_scopes } diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs index 2a24e956..6a91eeeb 100644 --- a/v-api/src/endpoints/login/oauth/remote/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -166,6 +166,10 @@ impl OAuthProvider for GoogleOAuthProvider { "https://people.googleapis.com/v1/people/me?personFields=names", ] } + + fn expires_in(&self) -> Option { + None + } fn default_scopes(&self) -> &[String] { &self.default_scopes } diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index 211b8707..b1c18597 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -137,6 +137,10 @@ impl OAuthProvider for ZendeskOAuthProvider { fn user_info_endpoints(&self) -> Vec<&str> { vec![&self.user_info_endpoint] } + + fn expires_in(&self) -> Option { + Some(172800) + } fn default_scopes(&self) -> &[String] { &self.default_scopes } From 9dbfee44fab63b1c9ac45aa535a4ad90c4afe6be Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 13:18:18 -0500 Subject: [PATCH 40/51] More changes to align with spec --- .../src/endpoints/login/oauth/device_token.rs | 279 ------------------ v-api/src/endpoints/login/oauth/flow/code.rs | 129 +++++++- 2 files changed, 124 insertions(+), 284 deletions(-) delete mode 100644 v-api/src/endpoints/login/oauth/device_token.rs diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs deleted file mode 100644 index ef0c4e72..00000000 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ /dev/null @@ -1,279 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use chrono::{DateTime, Utc}; -use dropshot::{Body, HttpError, HttpResponseOk, Method, Path, RequestContext, TypedBody}; -use http::{HeaderValue, Response, StatusCode, header}; -use oauth2::{EmptyExtraTokenFields, StandardTokenResponse, TokenResponse, basic::BasicTokenType}; -use schemars::JsonSchema; -use secrecy::ExposeSecret; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; -use tap::TapFallible; -use tracing::instrument; -use v_model::permissions::PermissionStorage; - -use super::{ - ClientType, OAuthProvider, OAuthProviderInfo, OAuthProviderNameParam, UserInfoProvider, -}; -use crate::{ - context::ApiContext, endpoints::login::LoginError, error::ApiError, - permissions::VAppPermission, response::internal_error, util::response::bad_request, -}; - -#[instrument(skip(rqctx), err(Debug))] -pub async fn get_device_provider_op( - rqctx: &RequestContext>, - path: Path, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let path = path.into_inner(); - - tracing::trace!("Getting OAuth data for {}", path.provider); - - let provider = rqctx - .v_ctx() - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - Ok(HttpResponseOk(provider.provider_info( - rqctx.v_ctx().public_url(), - &ClientType::Device, - ))) -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct AccessTokenExchangeRequest { - pub device_code: String, - pub grant_type: String, - pub expires_at: Option>, -} - -#[derive(Serialize)] -pub struct AccessTokenExchange { - provider: ProviderTokenExchange, - expires_at: Option>, -} - -#[derive(Serialize)] -pub struct ProviderTokenExchange { - client_id: String, - device_code: String, - grant_type: String, - client_secret: String, -} - -impl AccessTokenExchange { - pub fn new( - req: AccessTokenExchangeRequest, - provider: &(dyn OAuthProvider + Send + Sync), - ) -> Option { - provider - .client_secret(&ClientType::Device) - .map(|client_secret| Self { - provider: ProviderTokenExchange { - client_id: provider.client_id(&ClientType::Device).to_string(), - device_code: req.device_code, - grant_type: req.grant_type, - client_secret: client_secret.expose_secret().to_string(), - }, - expires_at: req.expires_at, - }) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenResponse { - pub access_token: String, - pub token_type: String, - pub expires_in: Option, - pub refresh_token: Option, - pub scopes: Option>, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct ProxyTokenError { - error: String, - error_description: Option, - error_uri: Option, -} - -// Complete a device exchange request against the specified provider. This effectively proxies the -// requests that would go to the provider, captures the returned access tokens, and registers a -// new internal user as needed. The user is then returned an token that is valid for interacting -// with the API -#[instrument(skip(rqctx, body), err(Debug))] -pub async fn exchange_device_token_op( - rqctx: &RequestContext>, - path: Path, - body: TypedBody, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); - - let exchange_request = body.into_inner(); - - if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { - let token_exchange_endpoint = provider.token_exchange_endpoint(); - let client = reqwest::Client::new(); - - let response = client - .request(Method::POST, token_exchange_endpoint) - .header(header::CONTENT_TYPE, provider.token_exchange_content_type()) - .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .body( - // We know that this is safe to unwrap as we just deserialized it via the body Extractor - serde_urlencoded::to_string(&exchange.provider).unwrap(), - ) - .send() - .await - .tap_err(|err| tracing::error!(?err, "Token exchange request failed")) - .map_err(internal_error)?; - - // Take a part the response as we will need the individual parts later - let status = response.status(); - let headers = response.headers().clone(); - let bytes = response.bytes().await.map_err(internal_error)?; - - // We unfortunately can not trust our providers to follow specs and therefore need to do - // our own inspection of the response to determine what to do - if !status.is_success() { - // If the server returned a non-success status then we are going to trust the server and - // report their error back to the client - tracing::debug!(provider = ?path.provider, ?headers, ?status, "Received error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - Ok(client_response) - } else { - // The server gave us back a non-error response but it still may not be a success. - // GitHub for instance does not use a status code for indicating the success or failure - // of a call. So instead we try to deserialize the body into an access token, with the - // understanding that it may fail and we will need to try and treat the response as - // an error instead. - - let parsed: Result< - StandardTokenResponse, - serde_json::Error, - > = serde_json::from_slice(&bytes); - - match parsed { - Ok(parsed) => { - let info = provider - .get_user_info(parsed.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| { - tracing::error!(?err, "Failed to look up user information") - })?; - - tracing::debug!("Verified and validated OAuth user"); - - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, api_user_provider_id = ?api_user_provider.id, "Retrieved api user to generate device token for"); - - let claims = - ctx.generate_claims(&api_user_info.user.id, &api_user_provider.id, None); - let token = ctx - .user - .register_access_token( - &ctx.builtin_registration_user(), - ctx.jwt_signer(), - &api_user_info.user.id, - &claims, - ) - .await?; - - tracing::info!(provider = ?path.provider, api_user_id = ?api_user_info.user.id, "Generated access token"); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "application/json") - .body( - serde_json::to_string(&ProxyTokenResponse { - access_token: token.signed_token, - token_type: "Bearer".to_string(), - expires_in: Some(claims.exp - Utc::now().timestamp()), - refresh_token: None, - scopes: None, - }) - .unwrap() - .into(), - )?) - } - Err(_) => { - // Do not log the error here as we want to ensure we do not leak token information - tracing::debug!( - "Failed to parse a success response from the remote token endpoint" - ); - - // Try to deserialize the body again, but this time as an error - let mut error_response = match serde_json::from_slice::(&bytes) - { - Ok(error) => { - // We found an error in the message body. This is not ideal, but we at - // least can understand what the server was trying to tell us - tracing::debug!(?error, provider = ?path.provider, "Parsed error response from OAuth provider"); - - let mut client_response = Response::new(Body::from(bytes)); - *client_response.headers_mut() = headers; - *client_response.status_mut() = status; - - client_response - } - Err(_) => { - // We still do not know what the remote server is doing... and need to - // cancel the request ourselves - tracing::warn!( - "Remote OAuth provide returned a response that we do not undestand" - ); - - Response::new( - serde_json::to_vec(&ProxyTokenError { - error: "access_denied".to_string(), - error_description: Some(format!( - "{} returned a malformed response", - path.provider - )), - error_uri: None, - }) - .unwrap() - .into(), - ) - } - }; - - *error_response.status_mut() = StatusCode::BAD_REQUEST; - error_response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - Ok(error_response) - } - } - } - } else { - tracing::info!(provider = ?path.provider, "Found an OAuth provider, but it is not configured properly"); - - Err(bad_request("Invalid provider")) - } -} diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 6b55bcdf..fcaaa67a 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -15,7 +15,7 @@ use newtype_uuid::{GenericUuid, TypedUuid}; use oauth2::{ AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, }; -use percent_encoding::{NON_ALPHANUMERIC, percent_encode}; + use schemars::JsonSchema; use secrecy::SecretString; use serde::{Deserialize, Serialize}; @@ -308,10 +308,10 @@ where // TODO: Make this configurable attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); - // Add in the user defined state and redirect uri. State is an arbitrary value and may be - // malicious. It must be url-encoded before being presented back to the client. Therefore we - // process once before storing so all downstream consumers see the encoded value. - attempt.state = Some(percent_encode(query.state.as_bytes(), NON_ALPHANUMERIC).to_string()); + // Store the client's state value as-is. Per RFC 6749 §4.1.1, the authorization server + // MUST return the state parameter unmodified. The value will be properly percent-encoded + // when it is placed into the redirect URL by `callback_url()` via `append_pair`. + attempt.state = Some(query.state); // Always store the client's PKCE challenge so we can verify it during the token exchange. // This is the client-to-v-api PKCE leg and is mandatory for all flows. @@ -573,6 +573,8 @@ pub struct OAuthAuthzCodeExchangeResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, + /// The scope granted to the access token (RFC 6749 §5.1). + pub scope: String, pub idp_token: Option, } @@ -781,6 +783,7 @@ where token_type: "Bearer".to_string(), access_token: token.signed_token, expires_in: token.expires_in, + scope: attempt.scope.clone(), idp_token, })) } @@ -2369,6 +2372,122 @@ mod tests { ); } + /// Verifies that the `state` parameter survives the authorization code flow + /// round trip without modification, as required by RFC 6749 §4.1.1. The + /// authorization server MUST return the exact `state` value that the client + /// originally provided. This test uses a state value containing characters + /// that require percent-encoding (`+`, `/`, spaces, `&`, `=`) to ensure + /// they are encoded exactly once in the final redirect URL and decoded back + /// to the original value by standard URL parsing. + #[tokio::test] + async fn test_state_roundtrip_preserves_special_characters() { + let attempt_id = TypedUuid::new_v4(); + let original_state = "random+state/with spaces&special=chars"; + + // State is now stored as-is (no pre-encoding). callback_url() handles + // percent-encoding when building the redirect URL. + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some(original_state.to_string()), + pkce_challenge: Some("ox_challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: Some("v_verifier".to_string()), + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: String::new(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt.id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + Ok(returned) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + let ctx = mock_context(Arc::new(storage)).await; + + let location = + authz_code_callback_op_inner(&ctx, &attempt_id, Some("remote-code".to_string()), None) + .await + .unwrap(); + + let url = url::Url::parse(&location).unwrap(); + let returned_state = url + .query_pairs() + .find(|(k, _)| k == "state") + .map(|(_, v)| v.into_owned()) + .expect("state parameter must be present in callback URL"); + + // RFC 6749 §4.1.1: the state value MUST be returned to the client + // unmodified. The client sent `original_state`, so it should get back + // exactly `original_state` after URL decoding. + assert_eq!( + original_state, returned_state, + "RFC 6749 §4.1.1 requires the state parameter to be returned unmodified. \ + The client sent {:?} but received {:?}.", + original_state, returned_state, + ); + } + + /// RFC 6749 §5.1 requires the token response to include a `scope` parameter + /// when the issued scope differs from what the client requested, and recommends + /// it in all cases. The token response should echo back the scope that was + /// granted so clients can verify what permissions they received. + #[tokio::test] + async fn test_exchange_response_includes_scope() { + let storage = mock_exchange_storage(vec![VPermission::CreateAccessToken]); + let ctx = mock_context(Arc::new(storage)).await; + let attempt = mock_completed_attempt(); // scope = "user:info:r" + let info = UserInfo { + external_id: ExternalUserId::Google("test-google-id".to_string()), + verified_emails: vec!["user@example.com".to_string()], + display_name: Some("Test User".to_string()), + idp_token: None, + }; + + let response = super::complete_exchange(&ctx, info, &attempt, false) + .await + .unwrap() + .0; + + // Serialize the response to JSON and check for a "scope" field. + // Per RFC 6749 §5.1, the authorization server SHOULD include the scope + // in the token response, and MUST include it if it differs from what + // the client requested. + let json = serde_json::to_value(&response).unwrap(); + assert!( + json.get("scope").is_some(), + "Token response must include a 'scope' field per RFC 6749 §5.1. \ + The login attempt had scope {:?} but the response was: {}", + attempt.scope, + serde_json::to_string_pretty(&json).unwrap(), + ); + } + #[tokio::test] async fn test_exchange_omits_idp_token_when_not_requested() { let storage = mock_exchange_storage(vec![ From 2115ca40f0af200d019be654c4783c9d24a94667 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 13:39:35 -0500 Subject: [PATCH 41/51] Verify provider --- v-api/src/context/login.rs | 2 + v-api/src/endpoints/login/oauth/flow/code.rs | 77 +++++++++++++++++++- v-model/src/storage/mod.rs | 1 + v-model/src/storage/postgres.rs | 5 ++ 4 files changed, 84 insertions(+), 1 deletion(-) diff --git a/v-api/src/context/login.rs b/v-api/src/context/login.rs index 3f9bd234..d2dde23a 100644 --- a/v-api/src/context/login.rs +++ b/v-api/src/context/login.rs @@ -62,10 +62,12 @@ where pub async fn get_login_attempt_for_code( &self, code: &str, + provider: &str, ) -> Result, StoreError> { let filter = LoginAttemptFilter { attempt_state: Some(vec![LoginAttemptState::RemoteAuthenticated]), authz_code: Some(vec![code.to_string()]), + provider: Some(vec![provider.to_string()]), ..Default::default() }; diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index fcaaa67a..5082b19c 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -678,7 +678,7 @@ where // Lookup the request assigned to this code let mut attempt = ctx .login - .get_login_attempt_for_code(&body.code) + .get_login_attempt_for_code(&body.code, &provider.name().to_string()) .await .map_err(to_internal_error)? .ok_or(OAuthError { @@ -2508,4 +2508,79 @@ mod tests { "IdP token must NOT be returned when not requested, even with permission" ); } + + /// The authorization code lookup should filter by provider so that a code + /// issued for one provider (e.g. Google) is not returned when exchanging + /// against a different provider (e.g. GitHub). This is a defense-in-depth + /// measure — codes should be scoped to their issuing provider at the query + /// level rather than relying solely on post-lookup validation. + #[tokio::test] + async fn test_code_lookup_filters_by_provider() { + // Create a login attempt that was authenticated via Google + let google_attempt = LoginAttempt { + id: TypedUuid::new_v4(), + attempt_state: LoginAttemptState::RemoteAuthenticated, + client_id: TypedUuid::new_v4(), + redirect_uri: "https://test.oxeng.dev/callback".to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: Some("authz-code-for-google".to_string()), + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: Some("remote-code".to_string()), + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + }; + + // The mock store simulates a real database: it only returns the + // attempt when the filter's provider field matches. + let returned_attempt = google_attempt.clone(); + let mut attempt_store = MockLoginAttemptStore::new(); + attempt_store.expect_list().returning(move |filter, _| { + let dominated = &returned_attempt; + if let Some(providers) = &filter.provider { + if providers.iter().any(|p| p == &dominated.provider) { + Ok(vec![dominated.clone()]) + } else { + Ok(vec![]) + } + } else { + Ok(vec![dominated.clone()]) + } + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + let ctx = mock_context(Arc::new(storage)).await; + + // Looking up the code for the correct provider should succeed. + let google_result = ctx + .login + .get_login_attempt_for_code("authz-code-for-google", "google") + .await + .unwrap(); + assert!( + google_result.is_some(), + "Code lookup for the issuing provider must return the attempt" + ); + + // Looking up the same code but for a different provider should return + // None, because the provider filter now scopes the query. + let github_result = ctx + .login + .get_login_attempt_for_code("authz-code-for-google", "github") + .await + .unwrap(); + assert!( + github_result.is_none(), + "Code lookup must not return an attempt for a different provider. \ + Expected None, but got {:?}.", + github_result.as_ref().map(|a| &a.provider), + ); + } } diff --git a/v-model/src/storage/mod.rs b/v-model/src/storage/mod.rs index d99d692f..ae7b5d83 100644 --- a/v-model/src/storage/mod.rs +++ b/v-model/src/storage/mod.rs @@ -257,6 +257,7 @@ pub struct LoginAttemptFilter { pub client_id: Option>>, pub attempt_state: Option>, pub authz_code: Option>, + pub provider: Option>, } #[cfg_attr(feature = "mock", automock)] diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index 7b23ce1b..f5e8ebcd 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -705,6 +705,7 @@ impl LoginAttemptStore for PostgresStore { client_id, attempt_state, authz_code, + provider, } = filter; if let Some(id) = id { @@ -731,6 +732,10 @@ impl LoginAttemptStore for PostgresStore { query = query.filter(login_attempt::authz_code.eq_any(authz_code)); } + if let Some(provider) = provider { + query = query.filter(login_attempt::provider.eq_any(provider)); + } + let results = query .offset(pagination.offset) .limit(pagination.limit) From a20ef122991b2e5d57084b9d10324596ec5239dd Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 14:02:20 -0500 Subject: [PATCH 42/51] Fix for race with redirect uris --- v-api/src/endpoints/login/oauth/flow/code.rs | 174 ++++++++++++++++++- 1 file changed, 170 insertions(+), 4 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 5082b19c..9f389650 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -506,6 +506,23 @@ where } })?; + // Re-validate the redirect URI against the OAuth client's current registered URIs. + // The URI was checked when the login attempt was created, but it may have been removed + // since then. We must not redirect to a URI that is no longer registered (TOCTOU). + let client = ctx + .oauth + .get_oauth_client(&ctx.builtin_registration_user(), &attempt.client_id) + .await + .map_err(to_internal_error)?; + if !client.is_redirect_uri_valid(&attempt.redirect_uri) { + tracing::warn!( + redirect_uri = ?attempt.redirect_uri, + client_id = ?attempt.client_id, + "Login attempt redirect URI is no longer registered on the OAuth client" + ); + return Err(unauthorized()); + } + attempt = match (code, error) { (Some(code), None) => { tracing::info!(?attempt.id, "Received valid login attempt. Storing authorization code"); @@ -1085,6 +1102,37 @@ mod tests { use super::{authorize_code_exchange, get_oauth_client, oauth_redirect_response}; + /// Create a mock `OAuthClientStore` that returns a client with the given + /// `client_id` and a single registered `redirect_uri`. This is needed by + /// any test that exercises `authz_code_callback_op_inner`, which re-validates + /// the redirect URI against the client before redirecting. + fn mock_oauth_client_store_for_callback( + client_id: TypedUuid, + redirect_uri: &str, + ) -> Arc { + let redirect_uri = redirect_uri.to_string(); + let mut store = MockOAuthClientStore::new(); + store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| { + Ok(Some(OAuthClient { + id: client_id, + secrets: vec![], + redirect_uris: vec![OAuthClientRedirectUri { + id: TypedUuid::new_v4(), + oauth_client_id: client_id, + redirect_uri: redirect_uri.clone(), + created_at: Utc::now(), + deleted_at: None, + }], + created_at: Utc::now(), + deleted_at: None, + })) + }); + Arc::new(store) + } + async fn mock_client() -> (VContext, OAuthClient, SecretString) { let ctx = mock_context(Arc::new(MockStorage::new())).await; let client_id = TypedUuid::new_v4(); @@ -1350,10 +1398,11 @@ mod tests { #[tokio::test] async fn test_callback_fails_when_error_is_passed() { let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); let attempt = LoginAttempt { id: attempt_id, attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), + client_id, redirect_uri: "https://test.oxeng.dev/callback".to_string(), state: Some("ox_state".to_string()), pkce_challenge: Some("ox_challenge".to_string()), @@ -1393,6 +1442,10 @@ mod tests { let mut storage = MockStorage::new(); storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); let ctx = mock_context(Arc::new(storage)).await; let location = authz_code_callback_op_inner( @@ -1413,10 +1466,11 @@ mod tests { #[tokio::test] async fn test_callback_forwards_access_denied() { let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); let attempt = LoginAttempt { id: attempt_id, attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), + client_id, redirect_uri: "https://test.oxeng.dev/callback".to_string(), state: Some("ox_state".to_string()), pkce_challenge: Some("ox_challenge".to_string()), @@ -1456,6 +1510,10 @@ mod tests { let mut storage = MockStorage::new(); storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); let ctx = mock_context(Arc::new(storage)).await; let location = authz_code_callback_op_inner( @@ -1476,10 +1534,11 @@ mod tests { #[tokio::test] async fn test_handles_callback_with_code() { let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); let attempt = LoginAttempt { id: attempt_id, attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), + client_id, redirect_uri: "https://test.oxeng.dev/callback".to_string(), state: Some("ox_state".to_string()), pkce_challenge: Some("ox_challenge".to_string()), @@ -1521,6 +1580,10 @@ mod tests { let mut storage = MockStorage::new(); storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); let ctx = mock_context(Arc::new(storage)).await; let location = @@ -2382,6 +2445,7 @@ mod tests { #[tokio::test] async fn test_state_roundtrip_preserves_special_characters() { let attempt_id = TypedUuid::new_v4(); + let client_id = TypedUuid::new_v4(); let original_state = "random+state/with spaces&special=chars"; // State is now stored as-is (no pre-encoding). callback_url() handles @@ -2389,7 +2453,7 @@ mod tests { let attempt = LoginAttempt { id: attempt_id, attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), + client_id, redirect_uri: "https://test.oxeng.dev/callback".to_string(), state: Some(original_state.to_string()), pkce_challenge: Some("ox_challenge".to_string()), @@ -2428,6 +2492,10 @@ mod tests { let mut storage = MockStorage::new(); storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(mock_oauth_client_store_for_callback( + client_id, + "https://test.oxeng.dev/callback", + )); let ctx = mock_context(Arc::new(storage)).await; let location = @@ -2509,6 +2577,104 @@ mod tests { ); } + /// The OAuth callback (`authz_code_callback_op_inner`) redirects the user to + /// the `redirect_uri` stored in the login attempt without re-validating it + /// against the OAuth client's currently registered redirect URIs. This means + /// that if a redirect URI is removed from the client between the authorization + /// request and the callback, the redirect still proceeds to the now-deregistered + /// URI (a TOCTOU gap). The callback should re-validate the redirect URI before + /// using it. + #[tokio::test] + async fn test_callback_revalidates_redirect_uri() { + let client_id = TypedUuid::new_v4(); + // The login attempt was created with a redirect_uri that was valid at the + // time, but has since been removed from the client's allowed list. + let deregistered_uri = "https://formerly-valid.example.com/callback"; + + let attempt_id = TypedUuid::new_v4(); + let attempt = LoginAttempt { + id: attempt_id, + attempt_state: LoginAttemptState::New, + client_id, + redirect_uri: deregistered_uri.to_string(), + state: Some("test-state".to_string()), + pkce_challenge: Some("test-challenge".to_string()), + pkce_challenge_method: Some("S256".to_string()), + authz_code: None, + expires_at: None, + error: None, + provider: "google".to_string(), + provider_pkce_verifier: None, + provider_authz_code: None, + provider_error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + scope: "user:info:r".to_string(), + }; + + let mut attempt_store = MockLoginAttemptStore::new(); + let original_attempt = attempt.clone(); + attempt_store + .expect_get() + .with(eq(attempt_id)) + .returning(move |_| Ok(Some(original_attempt.clone()))); + + attempt_store + .expect_update_if_state() + .withf(|attempt, expected| { + attempt.attempt_state == LoginAttemptState::RemoteAuthenticated + && *expected == LoginAttemptState::New + }) + .returning(move |arg, _| { + let mut returned = attempt.clone(); + returned.attempt_state = arg.attempt_state; + returned.authz_code = arg.authz_code; + Ok(returned) + }); + + // Configure the OAuth client with NO registered redirect URIs, + // simulating that the URI was removed after the login attempt + // was created. + let mut client_store = MockOAuthClientStore::new(); + client_store + .expect_get() + .with(eq(client_id), eq(false)) + .returning(move |_, _| { + Ok(Some(OAuthClient { + id: client_id, + secrets: vec![], + redirect_uris: vec![], // No registered URIs + created_at: Utc::now(), + deleted_at: None, + })) + }); + + let mut storage = MockStorage::new(); + storage.login_attempt_store = Some(Arc::new(attempt_store)); + storage.oauth_client_store = Some(Arc::new(client_store)); + let ctx = mock_context(Arc::new(storage)).await; + + // The callback should reject the request because the redirect URI is no + // longer registered on the OAuth client. + let err = authz_code_callback_op_inner( + &ctx, + &attempt_id, + Some("remote-code".to_string()), + None, + ) + .await + .expect_err( + "Callback should fail when the redirect URI is no longer registered on the client", + ); + + assert_eq!( + err.status_code, + StatusCode::UNAUTHORIZED, + "Expected 401 when redirect URI is deregistered, got {}", + err.status_code, + ); + } + /// The authorization code lookup should filter by provider so that a code /// issued for one provider (e.g. Google) is not returned when exchanging /// against a different provider (e.g. GitHub). This is a defense-in-depth From 8306f45c63f81aecfcc2bd2aa6244e1f6b80d160 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 14:05:33 -0500 Subject: [PATCH 43/51] Version bump --- Cargo.lock | 14 +++++++------- Cargo.toml | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 521070ec..f9d0d73f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -752,7 +752,7 @@ dependencies = [ [[package]] name = "dropshot-authorization-header" -version = "0.3.0" +version = "0.4.0" dependencies = [ "async-trait", "base64", @@ -3561,7 +3561,7 @@ dependencies = [ [[package]] name = "v-api" -version = "0.3.0" +version = "0.4.0" dependencies = [ "anyhow", "async-trait", @@ -3610,7 +3610,7 @@ dependencies = [ [[package]] name = "v-api-installer" -version = "0.3.0" +version = "0.4.0" dependencies = [ "diesel", "diesel_migrations", @@ -3618,7 +3618,7 @@ dependencies = [ [[package]] name = "v-api-param" -version = "0.3.0" +version = "0.4.0" dependencies = [ "secrecy", "serde", @@ -3629,7 +3629,7 @@ dependencies = [ [[package]] name = "v-api-permission-derive" -version = "0.3.0" +version = "0.4.0" dependencies = [ "heck", "newtype-uuid", @@ -3669,7 +3669,7 @@ dependencies = [ [[package]] name = "v-model" -version = "0.3.0" +version = "0.4.0" dependencies = [ "async-bb8-diesel", "async-trait", @@ -4328,7 +4328,7 @@ checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] name = "xtask" -version = "0.3.0" +version = "0.4.0" dependencies = [ "clap", "regex", diff --git a/Cargo.toml b/Cargo.toml index 8ccde396..f7066e83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ resolver = "2" [workspace.package] publish = true edition = "2024" -version = "0.3.0" +version = "0.4.0" [workspace.dependencies] anyhow = "1.0" From 4bc588e9c9655ad8cfc2a5487a5cae21407916df Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 14:12:36 -0500 Subject: [PATCH 44/51] Remove percent-encoding --- Cargo.lock | 1 - Cargo.toml | 1 - v-api/Cargo.toml | 1 - 3 files changed, 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9d0d73f..b8763bb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3583,7 +3583,6 @@ dependencies = [ "oauth2", "oauth2-reqwest", "partial-struct", - "percent-encoding", "rand 0.10.1", "reqwest", "rsa", diff --git a/Cargo.toml b/Cargo.toml index f7066e83..d3ad8588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,6 @@ oauth2 = { version = "5.0.0", default-features = false } oauth2-reqwest = "0.1.0-alpha.3" owo-colors = "4.2.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } -percent-encoding = "2.3.2" proc-macro2 = "1" progenitor-client = "0.14.0" quote = "1" diff --git a/v-api/Cargo.toml b/v-api/Cargo.toml index 35e0a586..0e036309 100644 --- a/v-api/Cargo.toml +++ b/v-api/Cargo.toml @@ -29,7 +29,6 @@ oauth2 = { workspace = true } oauth2-reqwest = { workspace = true } newtype-uuid = { workspace = true } partial-struct = { workspace = true } -percent-encoding = { workspace = true } rand = { workspace = true, features = ["std"] } reqwest = { workspace = true } rsa = { workspace = true, features = ["sha2"] } From 9b7d7db5939761e669ea952d1908012c7ca292f9 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:00:07 -0500 Subject: [PATCH 45/51] Fix new provider info --- v-api/src/endpoints/login/oauth/flow/code.rs | 1 + v-api/src/endpoints/login/oauth/remote/github.rs | 2 +- v-api/src/endpoints/login/oauth/remote/google.rs | 2 +- v-api/src/endpoints/login/oauth/remote/zendesk.rs | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 9f389650..6a2a183b 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -571,6 +571,7 @@ where #[derive(Debug, Deserialize, JsonSchema)] pub struct OAuthAuthzCodeExchangeQuery { + #[serde(default)] pub request_idp_token: bool, } diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index ec814495..dc19de81 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -51,7 +51,7 @@ impl GitHubOAuthProvider { auth_url_endpoint: format!("{}/login/oauth/github/code/authorize", public_url), redirect_endpoint: format!("{}/login/oauth/github/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: format!("{}/login/oauth/github/device/exchange", public_url), + token_endpoint: format!("{}/login/oauth/github/code/token", public_url), remote: OAuthProviderAuthorizationCodeRemoteInfo { client_id: web.remote_client_id, client_secret: web.remote_client_secret.into(), diff --git a/v-api/src/endpoints/login/oauth/remote/google.rs b/v-api/src/endpoints/login/oauth/remote/google.rs index 6a91eeeb..92bae32d 100644 --- a/v-api/src/endpoints/login/oauth/remote/google.rs +++ b/v-api/src/endpoints/login/oauth/remote/google.rs @@ -51,7 +51,7 @@ impl GoogleOAuthProvider { auth_url_endpoint: format!("{}/login/oauth/google/code/authorize", public_url), redirect_endpoint: format!("{}/login/oauth/google/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: format!("{}/login/oauth/google/device/exchange", public_url), + token_endpoint: format!("{}/login/oauth/google/code/token", public_url), remote: OAuthProviderAuthorizationCodeRemoteInfo { client_id: web.remote_client_id, client_secret: web.remote_client_secret.into(), diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index b1c18597..cffb63df 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -50,7 +50,7 @@ impl ZendeskOAuthProvider { auth_url_endpoint: format!("{}/login/oauth/zendesk/code/authorize", public_url), redirect_endpoint: format!("{}/login/oauth/zendesk/code/callback", public_url), token_endpoint_content_type: "application/x-www-form-urlencoded".to_string(), - token_endpoint: format!("{}/login/oauth/zendesk/device/exchange", public_url), + token_endpoint: format!("{}/login/oauth/zendesk/code/token", public_url), remote: OAuthProviderAuthorizationCodeRemoteInfo { client_id: web.remote_client_id, client_secret: web.remote_client_secret.into(), From d0842df4276b60682b883413ecf320730ffa5455 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:10:33 -0500 Subject: [PATCH 46/51] Clean up from main merge. Fix config error consumed. Update state transition error --- Cargo.lock | 2 +- v-api/src/config.rs | 4 +- v-api/src/endpoints/login/oauth/code.rs | 1637 ------------------ v-api/src/endpoints/login/oauth/flow/code.rs | 7 +- v-cli-sdk/Cargo.toml | 5 +- v-model/src/schema_ext.rs | 2 +- v-model/src/storage/postgres.rs | 4 +- 7 files changed, 10 insertions(+), 1651 deletions(-) delete mode 100644 v-api/src/endpoints/login/oauth/code.rs diff --git a/Cargo.lock b/Cargo.lock index b8763bb6..ea8930a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3645,7 +3645,7 @@ dependencies = [ [[package]] name = "v-cli-sdk" -version = "0.2.0" +version = "0.4.0" dependencies = [ "anyhow", "clap", diff --git a/v-api/src/config.rs b/v-api/src/config.rs index faf6bba1..cae7c2c1 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -202,8 +202,8 @@ impl OAuthConfig { .device .as_ref() .and_then(|d| d.resolve(base.clone()).ok()); - let web = self.web.as_ref().and_then(|w| w.resolve(base.clone()).ok()); - let proxy_web = self.proxy_web.as_ref().and_then(|p| p.resolve(base).ok()); + let web = self.web.as_ref().map(|w| w.resolve(base.clone())).transpose()?; + let proxy_web = self.proxy_web.as_ref().map(|p| p.resolve(base)).transpose()?; Ok(ResolvedOAuthConfig { device, web, diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs deleted file mode 100644 index 6ba1d0d1..00000000 --- a/v-api/src/endpoints/login/oauth/code.rs +++ /dev/null @@ -1,1637 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD}; -use chrono::{TimeDelta, Utc}; -use cookie::{Cookie, SameSite}; -use dropshot::{ - ClientErrorStatusCode, HttpError, HttpResponseOk, HttpResponseTemporaryRedirect, Path, Query, - RequestContext, RequestInfo, SharedExtractor, TypedBody, http_response_temporary_redirect, -}; -use dropshot_authorization_header::basic::BasicAuth; -use http::{HeaderValue, header::SET_COOKIE}; -use newtype_uuid::TypedUuid; -use oauth2::{ - AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, -}; -use percent_encoding::{NON_ALPHANUMERIC, percent_encode}; -use schemars::JsonSchema; -use secrecy::SecretString; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; -use std::{fmt::Debug, ops::Add}; -use tap::TapFallible; -use tracing::instrument; -use v_model::{ - LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, - permissions::{AsScope, PermissionStorage}, - schema_ext::LoginAttemptState, -}; - -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig}; -use crate::{ - authn::key::RawKey, - context::{ApiContext, VContext}, - endpoints::login::{ - LoginError, UserInfo, - oauth::{CheckOAuthClient, ClientType}, - }, - error::ApiError, - permissions::{VAppPermission, VPermission}, - secrets::OpenApiSecretString, - util::{ - request::RequestCookies, - response::{ResourceError, internal_error, to_internal_error, unauthorized}, - }, -}; - -static LOGIN_ATTEMPT_COOKIE: &str = "__v_login"; -static DEFAULT_SCOPE: &str = "user:info:r"; - -#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -struct OAuthError { - error: OAuthErrorCode, - #[serde(skip_serializing_if = "Option::is_none")] - error_description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error_uri: Option, - #[serde(skip_serializing_if = "Option::is_none")] - state: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize, PartialEq, Eq)] -#[serde(untagged)] -enum OAuthErrorCode { - AccessDenied, - InvalidClient, - InvalidGrant, - InvalidRequest, - InvalidScope, - ServerError, - TemporarilyUnavailable, - UnauthorizedClient, - UnsupportedGrantType, - UnsupportedResponseType, -} - -impl From for HttpError { - fn from(value: OAuthError) -> Self { - let serialized = serde_json::to_string(&value).unwrap(); - HttpError { - headers: None, - status_code: ClientErrorStatusCode::BAD_REQUEST.into(), - error_code: None, - external_message: serialized.clone(), - internal_message: serialized, - } - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeQuery { - pub client_id: TypedUuid, - pub redirect_uri: String, - pub response_type: String, - pub state: String, - pub scope: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeRedirectHeaders { - #[serde(rename = "set-cookies")] - cookies: String, - location: String, -} - -// Lookup the client specified by the provided client id and verify that the redirect uri -// is a valid for this client. If either of these fail we return an unauthorized response -async fn get_oauth_client( - ctx: &VContext, - client_id: &TypedUuid, - redirect_uri: &str, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let client = ctx - .oauth - .get_oauth_client(&ctx.builtin_registration_user(), client_id) - .await - .map_err(|err| { - tracing::error!(?err, "Failed to lookup OAuth client"); - - match err { - ResourceError::DoesNotExist => OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Unknown client id".to_string()), - error_uri: None, - state: None, - }, - // Given that the builtin caller should have access to all OAuth clients, any other - // error is considered an internal error - _ => OAuthError { - error: OAuthErrorCode::ServerError, - error_description: None, - error_uri: None, - state: None, - }, - } - })?; - - if client.is_redirect_uri_valid(redirect_uri) { - Ok(client) - } else { - Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }) - } -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_redirect_op( - rqctx: &RequestContext>, - path: Path, - query: Query, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let query = query.into_inner(); - - get_oauth_client(ctx, &query.client_id, &query.redirect_uri).await?; - - tracing::debug!(?query.client_id, ?query.redirect_uri, "Verified client id and redirect uri"); - - // Find the configured provider for the requested remote backend. We should always have a valid - // provider value, so if this fails then a 500 is returned - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code login"); - - // Check that the passed in scopes are valid. The scopes are not currently restricted by client - let scope = query.scope.unwrap_or_else(|| DEFAULT_SCOPE.to_string()); - let scope_error = VPermission::from_scope_arg(&scope) - .err() - .map(|_| "invalid_scope".to_string()); - - // Construct a new login attempt with the minimum required values - let mut attempt = NewLoginAttempt::new( - provider.name().to_string(), - query.client_id, - query.redirect_uri, - scope, - ) - .map_err(|err| { - tracing::error!(?err, "Attempted to construct invalid login attempt"); - internal_error("Attempted to construct invalid login attempt".to_string()) - })?; - - // Set a default expiration for the login attempt - // TODO: Make this configurable - attempt.expires_at = Some(Utc::now().add(TimeDelta::try_minutes(5).unwrap())); - - // Assign any scope errors that arose - attempt.error = scope_error; - - // Add in the user defined state and redirect uri. State is an arbitrary value and may be - // malicious. It must be url-encoded before being presented back to the client. Therefore we - // process once before storing so all downstream consumers see the encoded value. - attempt.state = Some(percent_encode(query.state.as_bytes(), NON_ALPHANUMERIC).to_string()); - - // If the remote provider supports pkce, set up a challenge - let pkce_challenge = if provider.supports_pkce() { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - attempt.provider_pkce_verifier = Some(pkce_verifier.secret().to_string()); - Some(pkce_challenge) - } else { - None - }; - - // Store the generated attempt - let attempt = ctx - .login - .create_login_attempt(attempt) - .await - .map_err(to_internal_error)?; - - tracing::info!(?attempt.id, "Created login attempt"); - - oauth_redirect_response(ctx.public_url(), &*provider, &attempt, pkce_challenge) -} - -fn oauth_redirect_response( - public_url: &str, - provider: &dyn OAuthProvider, - attempt: &LoginAttempt, - code_challenge: Option, -) -> Result { - // We may fail if the provider configuration is not correctly configured - // TODO: This behavior should be changed so that clients are precomputed. We do not need to be - // constructing a new client on every request. That said, we need to ensure the client does not - // maintain state between requests - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; - - // Create an attempt cookie header for storing the login attempt. This also acts as our csrf - // check - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, attempt.id.to_string()); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(public_url.starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(600)); - - let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; - - // Generate the url to the remote provider that the user will be redirected to - let mut authz_url = client - .authorize_url(|| CsrfToken::new(attempt.id.to_string())) - .add_scopes( - provider - .scopes() - .into_iter() - .map(|s| Scope::new(s.to_string())) - .collect::>(), - ); - - // If the caller has provided a code challenge, add it to the url - if let Some(challenge) = code_challenge { - authz_url = authz_url.set_pkce_challenge(challenge); - }; - - let mut redirect = http_response_temporary_redirect(authz_url.url().0.to_string())?; - redirect.headers_mut().append(SET_COOKIE, login_cookie); - - Ok(redirect) -} - -// TODO: Determine if 401 empty responses are correct here -fn verify_csrf( - request: &RequestInfo, - query: &OAuthAuthzCodeReturnQuery, -) -> Result, HttpError> { - // If we are missing the expected state parameter then we can not proceed at all with verifying - // this callback request. We also do not have a redirect uri to send the user to so we instead - // report unauthorized - let attempt_id = query - .state - .as_ref() - .ok_or_else(|| { - tracing::warn!("OAuth callback is missing a state parameter"); - unauthorized() - })? - .parse() - .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); - unauthorized() - })?; - - // The client must present the attempt cookie at a minimum. Without it we are unable to lookup a - // login attempt to match against. Without the cookie to verify the state parameter we can not - // determine a redirect uri so we instead report unauthorized - let attempt_cookie = request - .cookie(LOGIN_ATTEMPT_COOKIE) - .ok_or_else(|| { - tracing::warn!("OAuth callback is missing a login state cookie"); - unauthorized() - })? - .value() - .parse() - .map_err(|err| { - tracing::warn!(?err, "Failed to parse state"); - unauthorized() - })?; - - // Verify that the attempt_id returned from the state matches the expected client value. If they - // do not match we can not lookup a redirect uri so we instead return unauthorized - if attempt_id != attempt_cookie { - tracing::warn!( - ?attempt_id, - ?attempt_cookie, - "OAuth state does not match expected cookie value" - ); - Err(unauthorized()) - } else { - Ok(attempt_id) - } -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeReturnQuery { - pub state: Option, - pub code: Option, - pub error: Option, -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_callback_op( - rqctx: &RequestContext>, - path: Path, - query: Query, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let query = query.into_inner(); - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for authz code exchange"); - - // Verify and extract the attempt id before performing any work - let attempt_id = verify_csrf(&rqctx.request, &query)?; - - // Clear the login attempt cookie - let mut cookie = Cookie::new(LOGIN_ATTEMPT_COOKIE, ""); - cookie.set_http_only(true); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(ctx.public_url().starts_with("https")); - cookie.set_max_age(cookie::time::Duration::seconds(0)); - let login_cookie = HeaderValue::from_str(&cookie.to_string()).map_err(to_internal_error)?; - - let mut redirect = http_response_temporary_redirect( - authz_code_callback_op_inner(ctx, &attempt_id, query.code, query.error).await?, - )?; - redirect.headers_mut().append(SET_COOKIE, login_cookie); - - Ok(redirect) -} - -pub async fn authz_code_callback_op_inner( - ctx: &VContext, - attempt_id: &TypedUuid, - code: Option, - error: Option, -) -> Result -where - T: VAppPermission + PermissionStorage, -{ - // We have now verified the attempt id and can use it to look up the rest of the login attempt - // material to try and complete the flow - let mut attempt = ctx - .login - .get_login_attempt(attempt_id) - .await - .map_err(to_internal_error)? - .ok_or_else(|| { - // If we fail to find a matching attempt, there is not much we can do other than return - // unauthorized - unauthorized() - }) - .and_then(|attempt| { - if attempt.attempt_state == LoginAttemptState::New { - Ok(attempt) - } else { - Err(unauthorized()) - } - })?; - - attempt = match (code, error) { - (Some(code), None) => { - tracing::info!(?attempt.id, "Received valid login attempt. Storing authorization code"); - - // Store the authorization code returned by the underlying OAuth provider and transition the - // attempt to the awaiting state - ctx.login - .set_login_provider_authz_code(attempt, code.to_string()) - .await - .map_err(to_internal_error)? - } - (code, error) => { - tracing::info!(?attempt.id, ?error, "Received an error response from the remote server"); - - // Store the provider return error for future debugging, but if an error has been - // returned or there is a missing code, then we can not report a successful process - attempt.provider_authz_code = code; - - // When a user has explicitly denied access we want to forward that error message - // onwards to the upstream requester. All other errors should be opaque to the - // original requester and are returned as server errors - let error_message = match error.as_deref() { - Some("access_denied") => "access_denied", - _ => "server_error", - }; - - // TODO: Specialize the returned error - ctx.login - .fail_login_attempt(attempt, Some(error_message), error.as_deref()) - .await - .map_err(to_internal_error)? - } - }; - - // Redirect back to the original authenticator - Ok(attempt.callback_url()) -} - -#[derive(Debug, Deserialize, JsonSchema)] -pub struct OAuthAuthzCodeExchangeBody { - pub client_id: Option>, - pub client_secret: Option, - pub redirect_uri: String, - pub grant_type: String, - pub code: String, - pub pkce_verifier: Option, -} - -#[derive(Debug, Deserialize, JsonSchema, Serialize)] -pub struct OAuthAuthzCodeExchangeResponse { - pub access_token: String, - pub token_type: String, - pub expires_in: i64, -} - -#[instrument(skip(rqctx), err(Debug))] -pub async fn authz_code_exchange_op( - rqctx: &RequestContext>, - path: Path, - body: TypedBody, -) -> Result, HttpError> -where - T: VAppPermission + PermissionStorage, -{ - let ctx = rqctx.v_ctx(); - let path = path.into_inner(); - let body = body.into_inner(); - - let (client_id, client_secret) = - if let (Some(client_id), Some(client_secret)) = (body.client_id, body.client_secret) { - Ok::<_, HttpError>((client_id, client_secret)) - } else { - // Attempt to extract basic authorization credentials from the request if they were not - // present in the request body - let auth = ::from_request(rqctx) - .await - .tap_err(|err| { - tracing::warn!(?err, "Failed to extract basic authentication values"); - }); - let (client_id, client_secret) = match auth { - Ok(auth) if auth.username().is_some() && auth.password().is_some() => Ok(( - auth.username().unwrap().to_string(), - auth.password().unwrap().to_string(), - )), - _ => Err(internal_error( - "Missing client id and client secret from authz code exchange", - )), - }?; - - Ok(( - client_id.parse().map_err(to_internal_error)?, - OpenApiSecretString(client_secret.into()), - )) - }?; - - let provider = ctx - .get_oauth_provider(&path.provider) - .await - .map_err(ApiError::OAuth)?; - - tracing::debug!("Attempting code exchange"); - - // Verify the submitted client credentials - authorize_code_exchange( - ctx, - &body.grant_type, - client_id, - &client_secret.0, - &body.redirect_uri, - ) - .await?; - - tracing::debug!("Authorized code exchange"); - - // Lookup the request assigned to this code - let mut attempt = ctx - .login - .get_login_attempt_for_code(&body.code) - .await - .map_err(to_internal_error)? - .ok_or(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: None, - error_uri: None, - state: None, - })?; - - // Verify that the login attempt is valid and matches the submitted client credentials - verify_login_attempt( - &attempt, - client_id, - &body.redirect_uri, - body.pkce_verifier.as_deref(), - )?; - - tracing::debug!("Verified login attempt"); - - // Now that the attempt has been confirmed, use it to fetch user information form the remote - // provider - let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?; - - tracing::debug!("Retrieved user information from remote provider"); - - // During fetch_user_info we revoke any downstream codes if possible, therefore At this point we - // consider the login attempt to be consumed and can no longer be used. We state transition to - // complete, even though we may fail further along in the handler. If a failure occurs then the - // user will need to re-authenticate. - attempt = ctx - .login - .complete_login_attempt(attempt) - .await - .map_err(|err| { - tracing::error!(?err, "Failed to complete login attempt"); - OAuthError { - error: OAuthErrorCode::ServerError, - error_description: Some("An unexpected error occurred".to_string()), - error_uri: None, - state: None, - } - })?; - - // Register this user as an API user if needed - let (api_user_info, api_user_provider) = ctx - .register_api_user(&ctx.builtin_registration_user(), info) - .await?; - - tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); - - let scope = attempt - .scope - .split(' ') - .map(|s| s.to_string()) - .collect::>(); - - let token = ctx - .generate_access_token( - &ctx.builtin_registration_user(), - &api_user_info.user.id, - &api_user_provider.id, - Some(scope), - ) - .await?; - - Ok(HttpResponseOk(OAuthAuthzCodeExchangeResponse { - token_type: "Bearer".to_string(), - access_token: token.signed_token, - expires_in: token.expires_in, - })) -} - -async fn authorize_code_exchange( - ctx: &VContext, - grant_type: &str, - client_id: TypedUuid, - client_secret: &SecretString, - redirect_uri: &str, -) -> Result<(), OAuthError> -where - T: VAppPermission + PermissionStorage, -{ - let client = get_oauth_client(ctx, &client_id, redirect_uri).await?; - - // Verify that we received the expected grant type - if grant_type != "authorization_code" { - return Err(OAuthError { - error: OAuthErrorCode::UnsupportedGrantType, - error_description: None, - error_uri: None, - state: None, - }); - } - - tracing::debug!(grant_type, "Verified grant type"); - - let client_secret = RawKey::try_from(client_secret).map_err(|err| { - tracing::warn!(?err, "Failed to parse OAuth client secret"); - - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Malformed client secret".to_string()), - error_uri: None, - state: None, - } - })?; - - tracing::debug!("Constructed client secret"); - - if !client.is_secret_valid(&client_secret, ctx) { - Err(OAuthError { - error: OAuthErrorCode::InvalidClient, - error_description: Some("Invalid client secret".to_string()), - error_uri: None, - state: None, - }) - } else { - tracing::debug!("Verified client secret validity"); - - Ok(()) - } -} - -fn verify_login_attempt( - attempt: &LoginAttempt, - client_id: TypedUuid, - redirect_uri: &str, - pkce_verifier: Option<&str>, -) -> Result<(), OAuthError> { - if attempt.client_id != client_id { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid client id".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.redirect_uri != redirect_uri { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.attempt_state != LoginAttemptState::RemoteAuthenticated { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }) - } else if attempt.expires_at.map(|t| t <= Utc::now()).unwrap_or(true) { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant has expired".to_string()), - error_uri: None, - state: None, - }) - } else { - match (attempt.pkce_challenge.as_deref(), pkce_verifier) { - (Some(_), None) => Err(OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }), - (Some(challenge), Some(verifier)) => { - let mut hasher = Sha256::new(); - hasher.update(verifier); - let hash = hasher.finalize(); - let computed_challenge = BASE64_URL_SAFE_NO_PAD.encode(hash); - - if challenge == computed_challenge { - Ok(()) - } else { - Err(OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid pkce verifier".to_string()), - error_uri: None, - state: None, - }) - } - } - (None, _) => Ok(()), - } - } -} - -#[instrument(skip(attempt))] -async fn fetch_user_info( - public_url: &str, - client_type: &ClientType, - provider: &dyn OAuthProvider, - attempt: &LoginAttempt, -) -> Result { - // Exchange the stored authorization code with the remote provider for a remote access token - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; - - let mut request = client.exchange_code(AuthorizationCode::new( - attempt - .provider_authz_code - .as_ref() - .ok_or_else(|| { - internal_error("Expected authorization code to exist due to attempt state") - })? - .to_string(), - )); - - if let Some(pkce_verifier) = &attempt.provider_pkce_verifier { - request = request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string())) - } - - let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); - let response = request - .request_async(&oauth_client) - .await - .map_err(to_internal_error)?; - - tracing::info!("Fetched access token from remote service"); - - // Use the retrieved access token to fetch the user information from the remote API - let info = provider - .get_user_info(response.access_token().secret()) - .await - .map_err(LoginError::UserInfo) - .tap_err(|err| tracing::error!(?err, "Failed to look up user information"))?; - - tracing::info!("Fetched user info from remote service"); - - // Now that we are done with fetching user information from the remote API, we can revoke it if - // the provider supports it - if provider.token_revocation_endpoint().is_some() { - client - .revoke_token(response.access_token().into()) - .map_err(internal_error)? - .request_async(&oauth_client) - .await - .map_err(internal_error)?; - } - - Ok(info) -} - -#[cfg(test)] -mod tests { - use std::{ - net::{Ipv4Addr, SocketAddrV4}, - ops::Add, - sync::{Arc, Mutex}, - }; - - use chrono::{TimeDelta, Utc}; - use dropshot::{HttpResponse, RequestInfo}; - use http::{ - HeaderValue, StatusCode, - header::{COOKIE, LOCATION, SET_COOKIE}, - }; - use http_body_util::Empty; - use mockall::predicate::eq; - use newtype_uuid::TypedUuid; - use oauth2::PkceCodeChallenge; - use secrecy::SecretString; - use uuid::Uuid; - use v_model::{ - LoginAttempt, OAuthClient, OAuthClientRedirectUri, OAuthClientSecret, - schema_ext::LoginAttemptState, - storage::{MockLoginAttemptStore, MockOAuthClientStore}, - }; - - use crate::{ - authn::key::RawKey, - context::{ - VContext, - test_mocks::{MockStorage, mock_context}, - }, - endpoints::login::oauth::{ - OAuthProviderName, - code::{ - LOGIN_ATTEMPT_COOKIE, OAuthAuthzCodeReturnQuery, OAuthError, OAuthErrorCode, - authz_code_callback_op_inner, verify_csrf, verify_login_attempt, - }, - }, - permissions::VPermission, - }; - - use super::{authorize_code_exchange, get_oauth_client, oauth_redirect_response}; - - async fn mock_client() -> (VContext, OAuthClient, SecretString) { - let ctx = mock_context(Arc::new(MockStorage::new())).await; - let client_id = TypedUuid::new_v4(); - let key = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(ctx.signer()) - .await - .unwrap(); - let secret_signature = key.signature().to_string(); - let client_secret = key.key(); - let redirect_uri = "callback-destination"; - - ( - ctx, - OAuthClient { - id: client_id, - secrets: vec![OAuthClientSecret { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - secret_signature, - created_at: Utc::now(), - deleted_at: None, - }], - redirect_uris: vec![OAuthClientRedirectUri { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - redirect_uri: redirect_uri.to_string(), - created_at: Utc::now(), - deleted_at: None, - }], - created_at: Utc::now(), - deleted_at: None, - }, - client_secret, - ) - } - - #[tokio::test] - async fn test_oauth_client_lookup_checks_redirect_uri() { - let client_id = TypedUuid::new_v4(); - let client = OAuthClient { - id: client_id, - secrets: vec![], - redirect_uris: vec![OAuthClientRedirectUri { - id: TypedUuid::new_v4(), - oauth_client_id: client_id, - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - created_at: Utc::now(), - deleted_at: None, - }], - created_at: Utc::now(), - deleted_at: None, - }; - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let failure = get_oauth_client(&ctx, &client_id, "https://not-test.oxeng.dev/callback") - .await - .unwrap_err(); - assert_eq!(OAuthErrorCode::InvalidRequest, failure.error); - assert_eq!( - Some("Invalid redirect uri".to_string()), - failure.error_description - ); - - let success = get_oauth_client(&ctx, &client_id, "https://test.oxeng.dev/callback").await; - assert_eq!(client_id, success.unwrap().id); - } - - #[tokio::test] - async fn test_remote_provider_redirect_url() { - let storage = MockStorage::new(); - let mut ctx = mock_context(Arc::new(storage)).await; - ctx.with_public_url("https://api.oxeng.dev"); - - let (challenge, _) = PkceCodeChallenge::new_random_sha256(); - let attempt = LoginAttempt { - id: TypedUuid::new_v4(), - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let response = oauth_redirect_response( - ctx.public_url(), - &*ctx - .get_oauth_provider(&OAuthProviderName::Google) - .await - .unwrap(), - &attempt, - Some(challenge.clone()), - ) - .unwrap() - .to_result() - .unwrap(); - let headers = response.headers(); - - let expected_location = format!( - "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", - attempt.id, - challenge.as_str() - ); - - assert_eq!( - expected_location, - String::from_utf8(headers.get(LOCATION).unwrap().as_bytes().to_vec()).unwrap() - ); - assert_eq!( - format!( - "{}; HttpOnly; SameSite=Lax; Secure; Max-Age=600", - attempt.id - ) - .as_str(), - String::from_utf8(headers.get(SET_COOKIE).unwrap().as_bytes().to_vec()) - .unwrap() - .split_once('=') - .unwrap() - .1 - ) - } - - #[tokio::test] - async fn test_csrf_check() { - let id = TypedUuid::new_v4(); - - let mut rq = hyper::Request::new(Empty::<()>::new()); - rq.headers_mut().insert( - COOKIE, - HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, id)).unwrap(), - ); - let with_valid_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!(id, verify_csrf(&with_valid_cookie, &query).unwrap()); - - let query = OAuthAuthzCodeReturnQuery { - state: None, - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_valid_cookie, &query) - .unwrap_err() - .status_code - ); - - let mut rq = hyper::Request::new(Empty::<()>::new()); - rq.headers_mut().insert( - COOKIE, - HeaderValue::from_str(&format!("{}={}", LOGIN_ATTEMPT_COOKIE, Uuid::new_v4())).unwrap(), - ); - let with_invalid_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_invalid_cookie, &query) - .unwrap_err() - .status_code - ); - - let rq = hyper::Request::new(Empty::<()>::new()); - let with_missing_cookie = RequestInfo::new( - &rq, - std::net::SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 8888)), - ); - let query = OAuthAuthzCodeReturnQuery { - state: Some(id.to_string()), - code: None, - error: None, - }; - assert_eq!( - StatusCode::UNAUTHORIZED, - verify_csrf(&with_missing_cookie, &query) - .unwrap_err() - .status_code - ); - } - - #[tokio::test] - async fn test_callback_fails_when_not_in_new_state() { - let invalid_states = [ - LoginAttemptState::Complete, - LoginAttemptState::Failed, - LoginAttemptState::RemoteAuthenticated, - ]; - - for state in invalid_states { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: state, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut storage = MockStorage::new(); - let mut attempt_store = MockLoginAttemptStore::new(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(attempt.clone()))); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - - let ctx = mock_context(Arc::new(storage)).await; - let err = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - None, - ) - .await; - - assert_eq!(StatusCode::UNAUTHORIZED, err.unwrap_err().status_code); - } - } - - #[tokio::test] - async fn test_callback_fails_when_error_is_passed() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - returned.error = arg.error; - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - Some("not_access_denied".to_string()), - ) - .await - .unwrap(); - - assert_eq!( - format!("https://test.oxeng.dev/callback?error=server_error&state=ox_state",), - location - ); - } - - #[tokio::test] - async fn test_callback_forwards_access_denied() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::Failed) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - returned.error = arg.error; - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = authz_code_callback_op_inner( - &ctx, - &attempt_id, - Some("remote-code".to_string()), - Some("access_denied".to_string()), - ) - .await - .unwrap(); - - assert_eq!( - format!("https://test.oxeng.dev/callback?error=access_denied&state=ox_state",), - location - ); - } - - #[tokio::test] - async fn test_handles_callback_with_code() { - let attempt_id = TypedUuid::new_v4(); - let attempt = LoginAttempt { - id: attempt_id, - attempt_state: LoginAttemptState::New, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some("ox_challenge".to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: None, - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let mut attempt_store = MockLoginAttemptStore::new(); - let original_attempt = attempt.clone(); - attempt_store - .expect_get() - .with(eq(attempt.id)) - .returning(move |_| Ok(Some(original_attempt.clone()))); - - let extracted_code = Arc::new(Mutex::new(None)); - let extractor = extracted_code.clone(); - attempt_store - .expect_upsert() - .withf(|attempt| attempt.attempt_state == LoginAttemptState::RemoteAuthenticated) - .returning(move |arg| { - let mut returned = attempt.clone(); - returned.attempt_state = arg.attempt_state; - returned.authz_code = arg.authz_code; - *extractor.lock().unwrap() = returned.authz_code.clone(); - Ok(returned) - }); - - let mut storage = MockStorage::new(); - storage.login_attempt_store = Some(Arc::new(attempt_store)); - let ctx = mock_context(Arc::new(storage)).await; - - let location = - authz_code_callback_op_inner(&ctx, &attempt_id, Some("remote-code".to_string()), None) - .await - .unwrap(); - - let lock = extracted_code.lock(); - assert_eq!( - format!( - "https://test.oxeng.dev/callback?code={}&state=ox_state", - lock.unwrap().as_ref().unwrap() - ), - location - ); - } - - #[tokio::test] - async fn test_fails_callback_with_error() {} - - #[tokio::test] - async fn test_exchange_checks_client_id_and_redirect() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - let wrong_client_id = TypedUuid::new_v4(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(wrong_client_id), eq(false)) - .returning(move |_, _| Ok(None)); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - // 1. Verify exchange fails when passing an incorrect client id - assert_eq!( - Some("Unknown client id".to_string()), - authorize_code_exchange( - &ctx, - "authorization_code", - wrong_client_id, - &client_secret, - &redirect_uri, - ) - .await - .unwrap_err() - .error_description - ); - - // 2. Verify exchange fails when passing an incorrect redirect uri - assert_eq!( - Some("Invalid redirect uri".to_string()), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - "wrong-callback-destination", - ) - .await - .unwrap_err() - .error_description - ); - - // 3. Verify a successful exchange - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri, - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_exchange_checks_grant_type() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - assert_eq!( - OAuthErrorCode::UnsupportedGrantType, - authorize_code_exchange( - &ctx, - "not_authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_exchange_checks_for_valid_secret() { - let (mut ctx, client, client_secret) = mock_client().await; - let client_id = client.id; - let redirect_uri = client.redirect_uris[0].redirect_uri.clone(); - - let mut client_store = MockOAuthClientStore::new(); - client_store - .expect_get() - .with(eq(client_id), eq(false)) - .returning(move |_, _| Ok(Some(client.clone()))); - - let mut storage = MockStorage::new(); - storage.oauth_client_store = Some(Arc::new(client_store)); - - ctx.set_storage(Arc::new(storage)); - - let invalid_secret = RawKey::generate::<8>(&Uuid::new_v4()) - .sign(ctx.signer()) - .await - .unwrap() - .signature() - .to_string(); - - assert_eq!( - OAuthErrorCode::InvalidRequest, - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &"too-short".to_string().into(), - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - OAuthErrorCode::InvalidClient, - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &invalid_secret.into(), - &redirect_uri - ) - .await - .unwrap_err() - .error - ); - - assert_eq!( - (), - authorize_code_exchange( - &ctx, - "authorization_code", - client_id, - &client_secret, - &redirect_uri - ) - .await - .unwrap() - ); - } - - #[tokio::test] - async fn test_login_attempt_verification() { - let (challenge, verifier) = PkceCodeChallenge::new_random_sha256(); - let attempt = LoginAttempt { - id: TypedUuid::new_v4(), - attempt_state: LoginAttemptState::RemoteAuthenticated, - client_id: TypedUuid::new_v4(), - redirect_uri: "https://test.oxeng.dev/callback".to_string(), - state: Some("ox_state".to_string()), - pkce_challenge: Some(challenge.as_str().to_string()), - pkce_challenge_method: Some("S256".to_string()), - authz_code: None, - expires_at: Some(Utc::now().add(TimeDelta::try_seconds(60).unwrap())), - error: None, - provider: "google".to_string(), - provider_pkce_verifier: Some("v_verifier".to_string()), - provider_authz_code: None, - provider_error: None, - created_at: Utc::now(), - updated_at: Utc::now(), - scope: String::new(), - }; - - let bad_client_id = LoginAttempt { - client_id: TypedUuid::new_v4(), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid client id".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &bad_client_id, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let bad_redirect_uri = LoginAttempt { - redirect_uri: "https://bad.oxeng.dev/callback".to_string(), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid redirect uri".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &bad_redirect_uri, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let unconfirmed_state = LoginAttempt { - attempt_state: LoginAttemptState::New, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &unconfirmed_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let already_used_state = LoginAttempt { - attempt_state: LoginAttemptState::Complete, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &already_used_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let failed_state = LoginAttempt { - attempt_state: LoginAttemptState::Failed, - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant is in an invalid state".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &failed_state, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let expired = LoginAttempt { - expires_at: Some(Utc::now()), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Grant has expired".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &expired, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - let missing_pkce = LoginAttempt { ..attempt.clone() }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidRequest, - error_description: Some("Missing pkce verifier".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &missing_pkce, - attempt.client_id, - &attempt.redirect_uri, - None, - ) - .unwrap_err() - ); - - let invalid_pkce = LoginAttempt { - pkce_challenge: Some("no-the-correct-value".to_string()), - ..attempt.clone() - }; - - assert_eq!( - OAuthError { - error: OAuthErrorCode::InvalidGrant, - error_description: Some("Invalid pkce verifier".to_string()), - error_uri: None, - state: None, - }, - verify_login_attempt( - &invalid_pkce, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap_err() - ); - - assert_eq!( - (), - verify_login_attempt( - &attempt, - attempt.client_id, - &attempt.redirect_uri, - Some(verifier.secret().as_str()), - ) - .unwrap() - ); - } -} diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 6a2a183b..27988985 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -37,7 +37,7 @@ use crate::{ context::{ApiContext, VContext}, endpoints::login::{ LoginError, UserInfo, - oauth::{CheckOAuthClient, ClientType, OAuthProviderAuthorizationCodePkceInfo}, + oauth::{CheckOAuthClient, OAuthProviderAuthorizationCodePkceInfo}, }, error::ApiError, permissions::{VAppPermission, VPermission}, @@ -747,7 +747,6 @@ where // re-authenticate. let info = fetch_user_info( ctx.public_url(), - &ctx.web_client(), &*provider, &attempt, !query.request_idp_token, @@ -992,7 +991,6 @@ fn verify_login_attempt( #[instrument(skip(attempt))] async fn fetch_user_info( public_url: &str, - client_type: &ClientType, provider: &dyn OAuthProvider, attempt: &LoginAttempt, revoke_idp_token: bool, @@ -1602,9 +1600,6 @@ mod tests { ); } - #[tokio::test] - async fn test_fails_callback_with_error() {} - #[tokio::test] async fn test_exchange_checks_client_id_and_redirect() { let (mut ctx, client, client_secret) = mock_client().await; diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml index 962e0217..4054def7 100644 --- a/v-cli-sdk/Cargo.toml +++ b/v-cli-sdk/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "v-cli-sdk" -version = "0.2.0" -edition = "2021" +version.workspace = true +edition.workspace = true +publish.workspace = true [dependencies] anyhow = { workspace = true } diff --git a/v-model/src/schema_ext.rs b/v-model/src/schema_ext.rs index 8ac7126d..303bfe4e 100644 --- a/v-model/src/schema_ext.rs +++ b/v-model/src/schema_ext.rs @@ -51,7 +51,7 @@ macro_rules! sql_conversion { } #[derive( - Debug, PartialEq, Clone, FromSqlRow, AsExpression, Serialize, Deserialize, JsonSchema, Default, + Copy, Debug, PartialEq, Clone, FromSqlRow, AsExpression, Serialize, Deserialize, JsonSchema, Default, )] #[diesel(sql_type = AttemptState)] #[serde(rename_all = "lowercase")] diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index f5e8ebcd..267d873c 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -807,8 +807,8 @@ impl LoginAttemptStore for PostgresStore { match result { Some(attempt) => Ok(LoginAttempt::from(attempt)), None => Err(StoreError::InvariantFailed(format!( - "Login attempt {} is not in expected state for transition to {}", - attempt.id, attempt.attempt_state + "Login attempt {} is not in expected state for transition to {}. It is in {}", + attempt.id, attempt.attempt_state, expected_state ))), } } From 09f4eed5411cdbee779cc42dd595689d2665aeb0 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:16:02 -0500 Subject: [PATCH 47/51] One more error propagation fix --- v-api/src/config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v-api/src/config.rs b/v-api/src/config.rs index cae7c2c1..f4269167 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -201,7 +201,7 @@ impl OAuthConfig { let device = self .device .as_ref() - .and_then(|d| d.resolve(base.clone()).ok()); + .map(|d| d.resolve(base.clone())).transpose()?; let web = self.web.as_ref().map(|w| w.resolve(base.clone())).transpose()?; let proxy_web = self.proxy_web.as_ref().map(|p| p.resolve(base)).transpose()?; Ok(ResolvedOAuthConfig { From 51fb286d6b4ac01745808345da3888ae90bbf753 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:18:23 -0500 Subject: [PATCH 48/51] More cleanup and notes --- v-api/src/config.rs | 15 ++++++++++++--- v-api/src/endpoints/login/oauth/remote/github.rs | 2 +- v-api/src/endpoints/login/oauth/remote/zendesk.rs | 2 ++ v-cli-sdk/src/printer/mod.rs | 2 +- v-model/src/schema_ext.rs | 11 ++++++++++- 5 files changed, 26 insertions(+), 6 deletions(-) diff --git a/v-api/src/config.rs b/v-api/src/config.rs index f4269167..0823d05a 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -201,9 +201,18 @@ impl OAuthConfig { let device = self .device .as_ref() - .map(|d| d.resolve(base.clone())).transpose()?; - let web = self.web.as_ref().map(|w| w.resolve(base.clone())).transpose()?; - let proxy_web = self.proxy_web.as_ref().map(|p| p.resolve(base)).transpose()?; + .map(|d| d.resolve(base.clone())) + .transpose()?; + let web = self + .web + .as_ref() + .map(|w| w.resolve(base.clone())) + .transpose()?; + let proxy_web = self + .proxy_web + .as_ref() + .map(|p| p.resolve(base)) + .transpose()?; Ok(ResolvedOAuthConfig { device, web, diff --git a/v-api/src/endpoints/login/oauth/remote/github.rs b/v-api/src/endpoints/login/oauth/remote/github.rs index dc19de81..4f3405bd 100644 --- a/v-api/src/endpoints/login/oauth/remote/github.rs +++ b/v-api/src/endpoints/login/oauth/remote/github.rs @@ -149,7 +149,7 @@ impl OAuthProvider for GitHubOAuthProvider { &self.default_scopes } fn supports_pkce(&self) -> bool { - false + true } fn authz_code_flow_info(&self) -> Option<&OAuthProviderAuthorizationCodeInfo> { diff --git a/v-api/src/endpoints/login/oauth/remote/zendesk.rs b/v-api/src/endpoints/login/oauth/remote/zendesk.rs index cffb63df..283645cd 100644 --- a/v-api/src/endpoints/login/oauth/remote/zendesk.rs +++ b/v-api/src/endpoints/login/oauth/remote/zendesk.rs @@ -139,6 +139,8 @@ impl OAuthProvider for ZendeskOAuthProvider { } fn expires_in(&self) -> Option { + // This is the maximum token duration that Zendesk supports. In the future we should make + // this configurable Some(172800) } fn default_scopes(&self) -> &[String] { diff --git a/v-cli-sdk/src/printer/mod.rs b/v-cli-sdk/src/printer/mod.rs index 4685b8ed..34db6cc2 100644 --- a/v-cli-sdk/src/printer/mod.rs +++ b/v-cli-sdk/src/printer/mod.rs @@ -59,7 +59,7 @@ impl Printer { // Check for 401 Unauthorized up-front, regardless of output format. if let Some(status) = value.status() { if status == reqwest::StatusCode::UNAUTHORIZED { - eprintln!("Authentication required. Please run `sprue auth login` first."); + eprintln!("Authentication required. Please run `auth login` first."); return; } } diff --git a/v-model/src/schema_ext.rs b/v-model/src/schema_ext.rs index 303bfe4e..769280c3 100644 --- a/v-model/src/schema_ext.rs +++ b/v-model/src/schema_ext.rs @@ -51,7 +51,16 @@ macro_rules! sql_conversion { } #[derive( - Copy, Debug, PartialEq, Clone, FromSqlRow, AsExpression, Serialize, Deserialize, JsonSchema, Default, + Copy, + Debug, + PartialEq, + Clone, + FromSqlRow, + AsExpression, + Serialize, + Deserialize, + JsonSchema, + Default, )] #[diesel(sql_type = AttemptState)] #[serde(rename_all = "lowercase")] From 2d264895a1ccd3d4e205c2a006c90770fd7f6473 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:19:25 -0500 Subject: [PATCH 49/51] Use trait mapping to HttpError --- v-api/src/endpoints/login/oauth/flow/code.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 27988985..7f26e405 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -512,8 +512,7 @@ where let client = ctx .oauth .get_oauth_client(&ctx.builtin_registration_user(), &attempt.client_id) - .await - .map_err(to_internal_error)?; + .await?; if !client.is_redirect_uri_valid(&attempt.redirect_uri) { tracing::warn!( redirect_uri = ?attempt.redirect_uri, From 1f26a147e83c9183a96438a79467d56e9e96ea85 Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 15:51:59 -0500 Subject: [PATCH 50/51] Fix for revocation when user has not requested access to a token --- v-api/src/endpoints/login/oauth/flow/code.rs | 210 +++++++++++++++---- v-cli-sdk/src/cmd/auth/oauth/code.rs | 20 +- v-cli-sdk/src/printer/mod.rs | 10 +- v-model/src/storage/postgres.rs | 2 +- 4 files changed, 188 insertions(+), 54 deletions(-) diff --git a/v-api/src/endpoints/login/oauth/flow/code.rs b/v-api/src/endpoints/login/oauth/flow/code.rs index 7f26e405..30a36997 100644 --- a/v-api/src/endpoints/login/oauth/flow/code.rs +++ b/v-api/src/endpoints/login/oauth/flow/code.rs @@ -743,25 +743,30 @@ where // Now that the attempt has been claimed, use it to fetch user information from the // remote provider. If this fails, the attempt is already consumed and the user must - // re-authenticate. - let info = fetch_user_info( - ctx.public_url(), - &*provider, - &attempt, - !query.request_idp_token, - ) - .await?; + // re-authenticate. The upstream access token is always preserved here so that + // revocation can be deferred until after the permission check. + let (info, upstream_token) = fetch_user_info(ctx.public_url(), &*provider, &attempt).await?; tracing::debug!("Retrieved user information from remote provider"); - complete_exchange(ctx, info, &attempt, query.request_idp_token).await + complete_exchange( + ctx, + info, + &*provider, + &attempt, + query.request_idp_token, + upstream_token, + ) + .await } async fn complete_exchange( ctx: &VContext, info: UserInfo, + provider: &dyn OAuthProvider, attempt: &LoginAttempt, request_idp_token: bool, + upstream_token: Option, ) -> Result, HttpError> where T: VAppPermission + PermissionStorage, @@ -778,6 +783,16 @@ where // only the directly assigned user permissions. let idp_token = filter_idp_token(ctx, idp_token, request_idp_token, &api_user_info).await; + // Revoke the upstream access token whenever it will NOT be returned to the caller. + // This covers the cases where the token was never requested, where the user lacks + // the RetrieveRemoteAccessToken permission, and where the provider did not return + // a token at all. + if idp_token.is_none() + && let Some(upstream) = upstream_token + { + revoke_upstream_token(provider, &upstream).await; + } + tracing::info!(api_user_id = ?api_user_info.user.id, "Retrieved api user to generate access token for"); let scope = attempt @@ -987,16 +1002,52 @@ fn verify_login_attempt( } } +/// Revoke an upstream IdP access token if the provider supports revocation. +/// Failures are logged but do not propagate — callers should not fail the +/// overall exchange just because revocation was unsuccessful. +async fn revoke_upstream_token(provider: &dyn OAuthProvider, token_secret: &str) { + let provider_info = match provider.authz_code_flow_info() { + Some(info) => info, + None => return, + }; + + if provider_info.remote.revocation_endpoint.is_some() { + let client = match provider.as_web_client() { + Ok(c) => c, + Err(err) => { + tracing::warn!( + ?err, + "Failed to build web client for upstream token revocation" + ); + return; + } + }; + let oauth_client: oauth2_reqwest::ReqwestClient = provider.client().clone().into(); + let access_token = oauth2::AccessToken::new(token_secret.to_string()); + match client.revoke_token(access_token.into()) { + Ok(req) => { + if let Err(err) = req.request_async(&oauth_client).await { + tracing::warn!(?err, "Failed to revoke upstream IdP access token"); + } + } + Err(err) => { + tracing::warn!( + ?err, + "Failed to build revocation request for upstream token" + ); + } + } + } else { + tracing::debug!("Provider does not support token revocation") + } +} + #[instrument(skip(attempt))] async fn fetch_user_info( public_url: &str, provider: &dyn OAuthProvider, attempt: &LoginAttempt, - revoke_idp_token: bool, -) -> Result { - let provider_info = provider - .authz_code_flow_info() - .ok_or_else(|| internal_error("Authorization code flow not supported"))?; +) -> Result<(UserInfo, Option), HttpError> { // Exchange the stored authorization code with the remote provider for a remote access token let client = provider.as_web_client().map_err(to_internal_error)?; @@ -1035,18 +1086,11 @@ async fn fetch_user_info( tracing::info!("Fetched user info from remote service"); - // Now that we are done with fetching user information from the remote API, we can revoke it if - // the provider supports it - if revoke_idp_token && provider_info.remote.revocation_endpoint.is_some() { - client - .revoke_token(response.access_token().into()) - .map_err(internal_error)? - .request_async(&oauth_client) - .await - .map_err(internal_error)?; - } + // Return the upstream access token alongside the user info so the caller + // can decide whether to revoke it after the permission check. + let upstream_token = Some(response.access_token().secret().to_string()); - Ok(info) + Ok((info, upstream_token)) } #[cfg(test)] @@ -1100,6 +1144,70 @@ mod tests { use super::{authorize_code_exchange, get_oauth_client, oauth_redirect_response}; + /// A minimal no-op `OAuthProvider` for unit tests that need to pass a + /// provider reference to `complete_exchange` without performing any real + /// network I/O. `authz_code_flow_info` returns `None`, so + /// `revoke_upstream_token` will short-circuit immediately. + #[derive(Debug)] + struct NoOpOAuthProvider { + client: reqwest::Client, + } + + impl NoOpOAuthProvider { + fn new() -> Self { + Self { + client: reqwest::Client::new(), + } + } + } + + impl crate::endpoints::login::oauth::ExtractUserInfo for NoOpOAuthProvider { + fn extract_user_info( + &self, + _data: &[hyper::body::Bytes], + ) -> Result { + unimplemented!("not used in tests") + } + } + + impl crate::endpoints::login::oauth::OAuthProvider for NoOpOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Google + } + fn initialize_headers(&self, _request: &mut reqwest::Request) {} + fn client(&self) -> &reqwest::Client { + &self.client + } + fn user_info_endpoints(&self) -> Vec<&str> { + vec![] + } + fn authz_code_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderAuthorizationCodeInfo> { + None + } + fn authz_code_pkce_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderAuthorizationCodePkceInfo> + { + None + } + fn device_code_flow_info( + &self, + ) -> Option<&crate::endpoints::login::oauth::OAuthProviderDeviceInfo> { + None + } + fn expires_in(&self) -> Option { + None + } + fn default_scopes(&self) -> &[String] { + &[] + } + fn supports_pkce(&self) -> bool { + false + } + } + /// Create a mock `OAuthClientStore` that returns a client with the given /// `client_id` and a single registered `redirect_uri`. This is needed by /// any test that exercises `authz_code_callback_op_inner`, which re-validates @@ -2396,11 +2504,19 @@ mod tests { let ctx = mock_context(Arc::new(storage)).await; let attempt = mock_completed_attempt(); let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); - let response = super::complete_exchange(&ctx, info, &attempt, true) - .await - .unwrap() - .0; + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + true, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; assert_eq!( response.idp_token, @@ -2418,11 +2534,19 @@ mod tests { let ctx = mock_context(Arc::new(storage)).await; let attempt = mock_completed_attempt(); let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); - let response = super::complete_exchange(&ctx, info, &attempt, true) - .await - .unwrap() - .0; + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + true, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; assert_eq!( response.idp_token, None, @@ -2532,7 +2656,9 @@ mod tests { idp_token: None, }; - let response = super::complete_exchange(&ctx, info, &attempt, false) + let provider = NoOpOAuthProvider::new(); + + let response = super::complete_exchange(&ctx, info, &provider, &attempt, false, None) .await .unwrap() .0; @@ -2560,11 +2686,19 @@ mod tests { let ctx = mock_context(Arc::new(storage)).await; let attempt = mock_completed_attempt(); let info = mock_user_info_with_idp_token(); + let provider = NoOpOAuthProvider::new(); - let response = super::complete_exchange(&ctx, info, &attempt, false) - .await - .unwrap() - .0; + let response = super::complete_exchange( + &ctx, + info, + &provider, + &attempt, + false, + Some("secret-upstream-token".to_string()), + ) + .await + .unwrap() + .0; assert_eq!( response.idp_token, None, diff --git a/v-cli-sdk/src/cmd/auth/oauth/code.rs b/v-cli-sdk/src/cmd/auth/oauth/code.rs index 8a471b88..92db7b45 100644 --- a/v-cli-sdk/src/cmd/auth/oauth/code.rs +++ b/v-cli-sdk/src/cmd/auth/oauth/code.rs @@ -178,10 +178,10 @@ impl CodeOAuth { .map_err(|e| anyhow::anyhow!(e))?; // Send the token back to the main task. - if let Ok(mut guard) = token_tx.lock() { - if let Some(tx) = guard.take() { - let _ = tx.send(Ok(token)); - } + if let Ok(mut guard) = token_tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(Ok(token)); } // Return a friendly page to the browser so the user @@ -210,12 +210,12 @@ impl CodeOAuth { // If the proxy died before we got a token, unblock the // receiver so the caller is not stuck forever. - if let Ok(mut guard) = error_token_tx.lock() { - if let Some(tx) = guard.take() { - let _ = tx.send(Err(anyhow::anyhow!( - "Proxy server exited unexpectedly: {e}" - ))); - } + if let Ok(mut guard) = error_token_tx.lock() + && let Some(tx) = guard.take() + { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); } } } diff --git a/v-cli-sdk/src/printer/mod.rs b/v-cli-sdk/src/printer/mod.rs index 34db6cc2..3fc117e1 100644 --- a/v-cli-sdk/src/printer/mod.rs +++ b/v-cli-sdk/src/printer/mod.rs @@ -57,11 +57,11 @@ impl Printer { T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug, { // Check for 401 Unauthorized up-front, regardless of output format. - if let Some(status) = value.status() { - if status == reqwest::StatusCode::UNAUTHORIZED { - eprintln!("Authentication required. Please run `auth login` first."); - return; - } + if let Some(status) = value.status() + && status == reqwest::StatusCode::UNAUTHORIZED + { + eprintln!("Authentication required. Please run `auth login` first."); + return; } match self { diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index 267d873c..a658531f 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -793,7 +793,7 @@ impl LoginAttemptStore for PostgresStore { .filter(login_attempt::attempt_state.eq(expected_state)), ) .set(( - login_attempt::attempt_state.eq(attempt.attempt_state.clone()), + login_attempt::attempt_state.eq(attempt.attempt_state), login_attempt::authz_code.eq(attempt.authz_code), login_attempt::expires_at.eq(attempt.expires_at), login_attempt::error.eq(attempt.error), From af5c72c1253f14e9edfbcf8ca37e88a4225ddc4e Mon Sep 17 00:00:00 2001 From: augustuswm Date: Thu, 7 May 2026 16:09:23 -0500 Subject: [PATCH 51/51] Fix incorrect debug message --- v-model/src/storage/postgres.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v-model/src/storage/postgres.rs b/v-model/src/storage/postgres.rs index a658531f..10224d0e 100644 --- a/v-model/src/storage/postgres.rs +++ b/v-model/src/storage/postgres.rs @@ -807,8 +807,8 @@ impl LoginAttemptStore for PostgresStore { match result { Some(attempt) => Ok(LoginAttempt::from(attempt)), None => Err(StoreError::InvariantFailed(format!( - "Login attempt {} is not in expected state for transition to {}. It is in {}", - attempt.id, attempt.attempt_state, expected_state + "Login attempt {} is not in expected state for transition to {}", + attempt.id, attempt.attempt_state, ))), } }