Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Allow endpoints and discovery mode override for upstream oauth2 provi…
Browse files Browse the repository at this point in the history
…ders

This time, at the configuration and database level
  • Loading branch information
sandhose committed Nov 17, 2023
1 parent 5669e50 commit 7981d9a
Show file tree
Hide file tree
Showing 12 changed files with 609 additions and 173 deletions.
78 changes: 63 additions & 15 deletions crates/cli/src/commands/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ use std::collections::HashSet;
use clap::Parser;
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
use mas_storage::{
upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess, SystemClock,
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
RepositoryAccess, SystemClock,
};
use mas_storage_pg::PgRepository;
use rand::SeedableRng;
use sqlx::{postgres::PgAdvisoryLock, Acquire};
use tracing::{info, info_span, warn};
use tracing::{error, info, info_span, warn};

use crate::util::database_connection_from_config;

Expand Down Expand Up @@ -204,10 +205,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
}

for provider in config.upstream_oauth2.providers {
let _span = info_span!("provider", %provider.id).entered();
if existing_ids.contains(&provider.id) {
info!(%provider.id, "Updating provider");
info!("Updating provider");
} else {
info!(%provider.id, "Adding provider");
info!("Adding provider");
}

if dry_run {
Expand All @@ -218,20 +220,65 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
.client_secret()
.map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
.transpose()?;
let client_auth_method = provider.client_auth_method();
let client_auth_signing_alg = provider.client_auth_signing_alg();
let token_endpoint_auth_method = provider.client_auth_method();
let token_endpoint_signing_alg = provider.client_auth_signing_alg();

let discovery_mode = match provider.discovery_mode {
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
}
mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
}
mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
}
};

if discovery_mode.is_disabled() {
if provider.authorization_endpoint.is_none() {
error!("Provider has discovery disabled but no authorization endpoint set");
}

if provider.token_endpoint.is_none() {
error!("Provider has discovery disabled but no token endpoint set");
}

if provider.jwks_uri.is_none() {
error!("Provider has discovery disabled but no JWKS URI set");
}
}

let pkce_mode = match provider.pkce_method {
mas_config::UpstreamOAuth2PkceMethod::Auto => {
mas_data_model::UpstreamOAuthProviderPkceMode::Auto
}
mas_config::UpstreamOAuth2PkceMethod::Always => {
mas_data_model::UpstreamOAuthProviderPkceMode::S256
}
mas_config::UpstreamOAuth2PkceMethod::Never => {
mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
}
};

repo.upstream_oauth_provider()
.upsert(
&clock,
provider.id,
provider.issuer,
provider.scope.parse()?,
client_auth_method,
client_auth_signing_alg,
provider.client_id,
encrypted_client_secret,
map_claims_imports(&provider.claims_imports),
UpstreamOAuthProviderParams {
issuer: provider.issuer,
scope: provider.scope.parse()?,
token_endpoint_auth_method,
token_endpoint_signing_alg,
client_id: provider.client_id,
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
},
)
.await?;
}
Expand Down Expand Up @@ -268,10 +315,11 @@ async fn sync(root: &super::Options, prune: bool, dry_run: bool) -> anyhow::Resu
}

for client in config.clients.iter() {
let _span = info_span!("client", client.id = %client.client_id).entered();
if existing_ids.contains(&client.client_id) {
info!(client.id = %client.client_id, "Updating client");
info!("Updating client");
} else {
info!(client.id = %client.client_id, "Adding client");
info!("Adding client");
}

if dry_run {
Expand Down
4 changes: 2 additions & 2 deletions crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ pub use self::{
},
templates::TemplatesConfig,
upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports,
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction,
ImportPreference as UpstreamOAuth2ImportPreference,
ImportPreference as UpstreamOAuth2ImportPreference, PkceMethod as UpstreamOAuth2PkceMethod,
SetEmailVerification as UpstreamOAuth2SetEmailVerification, UpstreamOAuth2Config,
},
};
Expand Down
62 changes: 62 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use ulid::Ulid;
use url::Url;

use crate::ConfigurationSection;

Expand Down Expand Up @@ -197,6 +198,39 @@ pub struct ClaimsImports {
pub email: EmailImportPreference,
}

/// How to discover the provider's configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
/// Use OIDC discovery with strict metadata verification
#[default]
Oidc,

/// Use OIDC discovery with relaxed metadata verification
Insecure,

/// Use a static configuration
Disabled,
}

/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
/// Use PKCE if the provider supports it
///
/// Defaults to no PKCE if provider discovery is disabled
#[default]
Auto,

/// Always use PKCE with the S256 challenge method
Always,

/// Never use PKCE
Never,
}

#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
Expand All @@ -220,6 +254,34 @@ pub struct Provider {
#[serde(flatten)]
pub token_auth_method: TokenAuthMethod,

/// How to discover the provider's configuration
///
/// Defaults to use OIDC discovery with strict metadata verification
#[serde(default)]
pub discovery_mode: DiscoveryMode,

/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
///
/// Defaults to `auto`, which uses PKCE if the provider supports it.
#[serde(default)]
pub pkce_method: PkceMethod,

/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
pub authorization_endpoint: Option<Url>,

/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery
pub token_endpoint: Option<Url>,

/// The URL to use for getting the provider's public keys
///
/// Defaults to the `jwks_uri` provided through discovery
pub jwks_uri: Option<Url>,

/// How claims should be imported from the `id_token` provided by the
/// provider
pub claims_imports: ClaimsImports,
Expand Down
77 changes: 77 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use chrono::{DateTime, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use oauth2_types::scope::Scope;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use ulid::Ulid;
use url::Url;

Expand All @@ -33,6 +34,48 @@ pub enum DiscoveryMode {
Disabled,
}

impl DiscoveryMode {
/// Returns `true` if discovery is disabled
#[must_use]
pub fn is_disabled(&self) -> bool {
matches!(self, DiscoveryMode::Disabled)
}
}

#[derive(Debug, Clone, Error)]
#[error("Invalid discovery mode {0:?}")]
pub struct InvalidDiscoveryModeError(String);

impl std::str::FromStr for DiscoveryMode {
type Err = InvalidDiscoveryModeError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"oidc" => Ok(Self::Oidc),
"insecure" => Ok(Self::Insecure),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidDiscoveryModeError(s.to_owned())),
}
}
}

impl DiscoveryMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Oidc => "oidc",
Self::Insecure => "insecure",
Self::Disabled => "disabled",
}
}
}

impl std::fmt::Display for DiscoveryMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PkceMode {
Expand All @@ -47,6 +90,40 @@ pub enum PkceMode {
Disabled,
}

#[derive(Debug, Clone, Error)]
#[error("Invalid PKCE mode {0:?}")]
pub struct InvalidPkceModeError(String);

impl std::str::FromStr for PkceMode {
type Err = InvalidPkceModeError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"auto" => Ok(Self::Auto),
"s256" => Ok(Self::S256),
"disabled" => Ok(Self::Disabled),
s => Err(InvalidPkceModeError(s.to_owned())),
}
}
}

impl PkceMode {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::Auto => "auto",
Self::S256 => "s256",
Self::Disabled => "disabled",
}
}
}

impl std::fmt::Display for PkceMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
Expand Down
3 changes: 1 addition & 2 deletions crates/handlers/src/upstream_oauth2/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,8 @@ mod tests {
use tower::BoxError;
use ulid::Ulid;

use crate::test_utils::init_tracing;

use super::*;
use crate::test_utils::init_tracing;

#[tokio::test]
async fn test_metadata_cache() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- Copyright 2023 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- Adds various endpoint overrides for oauth providers
ALTER TABLE upstream_oauth_providers
ADD COLUMN "jwks_uri_override" TEXT,
ADD COLUMN "authorization_endpoint_override" TEXT,
ADD COLUMN "token_endpoint_override" TEXT,
ADD COLUMN "discovery_mode" TEXT NOT NULL DEFAULT 'oidc',
ADD COLUMN "pkce_mode" TEXT NOT NULL DEFAULT 'auto';
5 changes: 5 additions & 0 deletions crates/storage-pg/src/iden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ pub enum UpstreamOAuthProviders {
TokenEndpointAuthMethod,
CreatedAt,
ClaimsImports,
DiscoveryMode,
PkceMode,
JwksUriOverride,
TokenEndpointOverride,
AuthorizationEndpointOverride,
}

#[derive(sea_query::Iden)]
Expand Down
Loading

0 comments on commit 7981d9a

Please sign in to comment.