Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Run email, password and registration policies in more places #1656

Merged
merged 3 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions crates/axum-utils/src/fancy_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ pub struct FancyError {
context: ErrorContext,
}

impl FancyError {
#[must_use]
pub fn new(context: ErrorContext) -> Self {
Self { context }
}
}

impl std::fmt::Display for FancyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let code = self.context.code().unwrap_or("Internal error");
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl Options {
// Listen for SIGHUP
register_sighup(&templates)?;

let graphql_schema = mas_handlers::graphql_schema(&pool, conn);
let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn);

let state = {
let mut s = AppState {
Expand Down
12 changes: 9 additions & 3 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ pub async fn policy_factory_from_config(
.await
.context("failed to open OPA WASM policy file")?;

let entrypoints = mas_policy::Entrypoints {
register: config.register_entrypoint.clone(),
client_registration: config.client_registration_entrypoint.clone(),
authorization_grant: config.authorization_grant_entrypoint.clone(),
email: config.email_entrypoint.clone(),
password: config.password_entrypoint.clone(),
};

PolicyFactory::load(
policy_file,
config.data.clone().unwrap_or_default(),
config.register_entrypoint.clone(),
config.client_registration_entrypoint.clone(),
config.authorization_grant_entrypoint.clone(),
entrypoints,
)
.await
.context("failed to load the policy")
Expand Down
18 changes: 18 additions & 0 deletions crates/config/src/sections/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ fn default_authorization_grant_endpoint() -> String {
"authorization_grant/violation".to_owned()
}

fn default_password_endpoint() -> String {
"password/violation".to_owned()
}

fn default_email_endpoint() -> String {
"email/violation".to_owned()
}

/// Application secrets
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
Expand All @@ -69,6 +77,14 @@ pub struct PolicyConfig {
#[serde(default = "default_authorization_grant_endpoint")]
pub authorization_grant_entrypoint: String,

/// Entrypoint to use when changing password
#[serde(default = "default_password_endpoint")]
pub password_entrypoint: String,

/// Entrypoint to use when adding an email address
#[serde(default = "default_email_endpoint")]
pub email_entrypoint: String,

/// Arbitrary data to pass to the policy
#[serde(default)]
pub data: Option<serde_json::Value>,
Expand All @@ -81,6 +97,8 @@ impl Default for PolicyConfig {
client_registration_entrypoint: default_client_registration_endpoint(),
register_entrypoint: default_register_endpoint(),
authorization_grant_entrypoint: default_authorization_grant_endpoint(),
password_entrypoint: default_password_endpoint(),
email_entrypoint: default_email_endpoint(),
data: None,
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/graphql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ url.workspace = true
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-matrix = { path = "../matrix" }
mas-policy = { path = "../policy" }
mas-storage = { path = "../storage" }

[[bin]]
Expand Down
28 changes: 26 additions & 2 deletions crates/graphql/src/mutations/user_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub enum AddEmailStatus {
Exists,
/// The email address is invalid
Invalid,
/// The email address is not allowed by the policy
Denied,
}

/// The payload of the `addEmail` mutation
Expand All @@ -57,6 +59,9 @@ enum AddEmailPayload {
Added(mas_data_model::UserEmail),
Exists(mas_data_model::UserEmail),
Invalid,
Denied {
violations: Vec<mas_policy::Violation>,
},
}

#[Object(use_type_description)]
Expand All @@ -67,6 +72,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(_) => AddEmailStatus::Added,
AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
AddEmailPayload::Invalid => AddEmailStatus::Invalid,
AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
}
}

Expand All @@ -76,7 +82,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
Some(UserEmail(email.clone()))
}
AddEmailPayload::Invalid => None,
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
}
}

Expand All @@ -87,7 +93,7 @@ impl AddEmailPayload {

let user_id = match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
AddEmailPayload::Invalid => return Ok(None),
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
};

let user = repo
Expand All @@ -98,6 +104,16 @@ impl AddEmailPayload {

Ok(Some(User(user)))
}

/// The list of policy violations if the email address was denied
async fn violations(&self) -> Option<Vec<String>> {
let AddEmailPayload::Denied { violations } = self else {
return None;
};

let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
}

/// The input for the `sendVerificationEmail` mutation
Expand Down Expand Up @@ -382,6 +398,14 @@ impl UserEmailMutations {
return Ok(AddEmailPayload::Invalid);
}

let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}

// Find an existing email address
let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, user_email) = if let Some(user_email) = existing_user_email {
Expand Down
2 changes: 2 additions & 0 deletions crates/graphql/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
// limitations under the License.

use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};

use crate::Requester;

#[async_trait::async_trait]
pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
async fn policy(&self) -> Result<Policy, mas_policy::InstantiateError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng;
Expand Down
45 changes: 30 additions & 15 deletions crates/handlers/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
response::IntoResponse,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository;
Expand All @@ -33,7 +33,6 @@ use opentelemetry::{
};
use rand::SeedableRng;
use sqlx::PgPool;
use thiserror::Error;

use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};

Expand Down Expand Up @@ -176,12 +175,6 @@ impl FromRef<AppState> for MatrixHomeserver {
}
}

impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}

impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone()
Expand Down Expand Up @@ -236,19 +229,41 @@ impl FromRequestParts<AppState> for BoxRng {
}
}

#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
/// A simple wrapper around an error that implements [`IntoResponse`].
pub struct ErrorWrapper<T>(T);

impl IntoResponse for RepositoryError {
fn into_response(self) -> axum::response::Response {
impl<T> From<T> for ErrorWrapper<T> {
fn from(input: T) -> Self {
Self(input)
}
}

impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}

#[async_trait]
impl FromRequestParts<AppState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}

#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError;
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
Expand Down
8 changes: 8 additions & 0 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt};
use mas_graphql::{Requester, Schema};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
};
Expand All @@ -48,6 +49,7 @@ mod tests;
struct GraphQLState {
pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
policy_factory: Arc<PolicyFactory>,
}

#[async_trait]
Expand All @@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState {
Ok(repo.map_err(RepositoryError::from_error).boxed())
}

async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}

fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver_connection.as_ref()
}
Expand All @@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState {
#[must_use]
pub fn schema(
pool: &PgPool,
policy_factory: &Arc<PolicyFactory>,
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection),
};
let state: mas_graphql::BoxState = Box::new(state);
Expand Down
8 changes: 4 additions & 4 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
clippy::let_with_type_underscore,
)]

use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{convert::Infallible, time::Duration};

use axum::{
body::{Bytes, HttpBody},
Expand All @@ -50,7 +50,7 @@ use hyper::{
use mas_axum_utils::{cookies::CookieJar, FancyError};
use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, NotFoundContext, Templates};
Expand Down Expand Up @@ -166,12 +166,12 @@ where
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
// All those routes are API-like, with a common CORS layer
Router::new()
Expand Down Expand Up @@ -267,7 +267,6 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>,
Encrypter: FromRef<S>,
Expand All @@ -278,6 +277,7 @@ where
MetadataCache: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
Router::new()
// XXX: hard-coded redirect from /account to /account/
Expand Down
Loading