Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Allow endpoints and discovery mode override for upstream oauth2 provi…
Browse files Browse the repository at this point in the history
…ders

This time, at the configuration and database level
  • Loading branch information
sandhose committed Nov 17, 2023
1 parent 5669e50 commit f6fc145
Show file tree
Hide file tree
Showing 17 changed files with 717 additions and 211 deletions.
78 changes: 63 additions & 15 deletions crates/cli/src/commands/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ use std::collections::HashSet;
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock,
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use rand::SeedableRng;
use sqlx::{postgres::PgAdvisoryLock, Acquire};
use tracing::{info, info_span, warn};
use tracing::{error, info, info_span, warn};

use crate::util::database_connection_from_config;

Expand Down Expand Up @@ -204,10 +205,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
}

for provider in config.upstream_oauth2.providers {
let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) {
info!(%provider.id, "Updating provider");
info!("Updating provider");
} else {
info!(%provider.id, "Adding provider");
info!("Adding provider");
}

if dry_run {
Expand All @@ -218,20 +220,65 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let client_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg();
let token_endpoint_auth_method = provider.client_auth_method();
let token_endpoint_signing_alg = provider.client_auth_signing_alg();

let discovery_mode = match provider.discovery_mode {
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
}
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
}
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
}
};

if discovery_mode.is_disabled() {
if provider.authorization_endpoint.is_none() {
error!("Provider has discovery disabled but no authorization endpoint set");
}

if provider.token_endpoint.is_none() {
error!("Provider has discovery disabled but no token endpoint set");
}

if provider.jwks_uri.is_none() {
error!("Provider has discovery disabled but no JWKS URI set");
}
}

let pkce_mode = match provider.pkce_method {
mas_config::UpstreamOAuth2PkceMethod::Auto => {
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
}
mas_config::UpstreamOAuth2PkceMethod::Always => {
mas_data_model::UpstreamOAuthProviderPkceMode::S256
}
mas_config::UpstreamOAuth2PkceMethod::Never => {
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
}
};

repo.upstream_oauth_provider()
.upsert(
&clock,
provider.id,
provider.issuer,
provider.scope.parse()?,
client_auth_method,
client_auth_signing_alg,
provider.client_id,
encrypted_client_secret,
map_claims_imports(&provider.claims_imports),
UpstreamOAuthProviderParams {
issuer: provider.issuer,
scope: provider.scope.parse()?,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
},
)
.await?;
}
Expand Down Expand Up @@ -268,10 +315,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
}

for client in config.clients.iter() {
let _span = info_span!("client", client.id = %client.client_id).entered();
if existing_ids.contains(&client.client_id) {
info!(client.id = %client.client_id, "Updating client");
info!("Updating client");
} else {
info!(client.id = %client.client_id, "Adding client");
info!("Adding client");
}

if dry_run {
Expand Down
4 changes: 2 additions & 2 deletions crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ pub use self::{
},
templates::TemplatesConfig,
upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports,
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction,
ImportPreference as UpstreamOAuth2ImportPreference,
ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
},
};
Expand Down
62 changes: 62 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use ulid::Ulid;
use url::Url;

use crate::ConfigurationSection;

Expand Down Expand Up @@ -197,6 +198,39 @@ pub struct ClaimsImports {
pub email: EmailImportPreference,
}

/// How to discover the provider's configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
/// Use OIDC discovery with strict metadata verification
#[default]
Oidc,

/// Use OIDC discovery with relaxed metadata verification
Insecure,

/// Use a static configuration
Disabled,
}

/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
/// Use PKCE if the provider supports it
///
/// Defaults to no PKCE if provider discovery is disabled
#[default]
Auto,

/// Always use PKCE with the S256 challenge method
Always,

/// Never use PKCE
Never,
}

#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
Expand All @@ -220,6 +254,34 @@ pub struct Provider {
#[serde(flatten)]
pub token_auth_method: TokenAuthMethod,

/// How to discover the provider's configuration
///
/// Defaults to use OIDC discovery with strict metadata verification
#[serde(default)]
pub discovery_mode: DiscoveryMode,

/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
///
/// Defaults to `auto`, which uses PKCE if the provider supports it.
#[serde(default)]
pub pkce_method: PkceMethod,

/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
pub token_endpoint: Option<Url>,

/// The URL to use for getting the provider's public keys
///
/// Defaults to the `jwks_uri` provided through discovery
pub jwks_uri: Option<Url>,

/// How claims should be imported from the `id_token` provided by the
/// provider
pub claims_imports: ClaimsImports,
Expand Down
77 changes: 77 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use ulid::Ulid;
use url::Url;

Expand All @@ -33,6 +34,48 @@ pub enum DiscoveryMode {
Disabled,
}

impl DiscoveryMode {
/// Returns `true` if discovery is disabled
#[must_use]
pub fn is_disabled(&self) -> bool {
matches!(self, DiscoveryMode::Disabled)
}
}

#[derive(Debug, Clone, Error)]
#[error("Invalid discovery mode {0:?}")]
pub struct InvalidDiscoveryModeError(String);

impl std::str::FromStr for DiscoveryMode {
type Err = InvalidDiscoveryModeError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"oidc" => Ok(Self::Oidc),
"insecure" => Ok(Self::Insecure),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidDiscoveryModeError(s.to_owned())),
}
}
}

impl DiscoveryMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Oidc => "oidc",
Self::Insecure => "insecure",
Self::Disabled => "disabled",
}
}
}

impl std::fmt::Display for DiscoveryMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PkceMode {
Expand All @@ -47,6 +90,40 @@ pub enum PkceMode {
Disabled,
}

#[derive(Debug, Clone, Error)]
#[error("Invalid PKCE mode {0:?}")]
pub struct InvalidPkceModeError(String);

impl std::str::FromStr for PkceMode {
type Err = InvalidPkceModeError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"s256" => Ok(Self::S256),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidPkceModeError(s.to_owned())),
}
}
}

impl PkceMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::S256 => "s256",
Self::Disabled => "disabled",
}
}
}

impl std::fmt::Display for PkceMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
Expand Down
3 changes: 1 addition & 2 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,8 @@ mod tests {
use tower::BoxError;
use ulid::Ulid;

use crate::test_utils::init_tracing;

use super::*;
use crate::test_utils::init_tracing;

#[tokio::test]
async fn test_metadata_cache() {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit f6fc145

Please sign in to comment.