Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Make sure we validate passwords & emails by the policy at all stages
Browse files Browse the repository at this point in the history
Also refactors the way we get the policy engines in requests
  • Loading branch information
sandhose committed Aug 30, 2023
1 parent 05aa5b6 commit cda799e
Show file tree
Hide file tree
Showing 30 changed files with 265 additions and 85 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions crates/axum-utils/src/fancy_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ pub struct FancyError {
context: ErrorContext,
}

impl FancyError {
pub fn new(context: ErrorContext) -> Self {
Self { context }
}
}

impl std::fmt::Display for FancyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let code = self.context.code().unwrap_or("Internal error");
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl Options {
// Listen for SIGHUP
register_sighup(&templates)?;

let graphql_schema = mas_handlers::graphql_schema(&pool, conn);
let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn);

let state = {
let mut s = AppState {
Expand Down
1 change: 1 addition & 0 deletions crates/graphql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ url.workspace = true
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-matrix = { path = "../matrix" }
mas-policy = { path = "../policy" }
mas-storage = { path = "../storage" }

[[bin]]
Expand Down
28 changes: 26 additions & 2 deletions crates/graphql/src/mutations/user_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub enum AddEmailStatus {
Exists,
/// The email address is invalid
Invalid,
/// The email address is not allowed by the policy
Denied,
}

/// The payload of the `addEmail` mutation
Expand All @@ -57,6 +59,9 @@ enum AddEmailPayload {
Added(mas_data_model::UserEmail),
Exists(mas_data_model::UserEmail),
Invalid,
Denied {
violations: Vec<mas_policy::Violation>,
},
}

#[Object(use_type_description)]
Expand All @@ -67,6 +72,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(_) => AddEmailStatus::Added,
AddEmailPayload::Exists(_) => AddEmailStatus::Exists,
AddEmailPayload::Invalid => AddEmailStatus::Invalid,
AddEmailPayload::Denied { .. } => AddEmailStatus::Denied,
}
}

Expand All @@ -76,7 +82,7 @@ impl AddEmailPayload {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => {
Some(UserEmail(email.clone()))
}
AddEmailPayload::Invalid => None,
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None,
}
}

Expand All @@ -87,7 +93,7 @@ impl AddEmailPayload {

let user_id = match self {
AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id,
AddEmailPayload::Invalid => return Ok(None),
AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None),
};

let user = repo
Expand All @@ -98,6 +104,16 @@ impl AddEmailPayload {

Ok(Some(User(user)))
}

/// The list of policy violations if the email address was denied
async fn violations(&self) -> Option<Vec<String>> {
let AddEmailPayload::Denied { violations } = self else {
return None;
};

let messages = violations.iter().map(|v| v.msg.clone()).collect();
Some(messages)
}
}

/// The input for the `sendVerificationEmail` mutation
Expand Down Expand Up @@ -382,6 +398,14 @@ impl UserEmailMutations {
return Ok(AddEmailPayload::Invalid);
}

let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
});
}

// Find an existing email address
let existing_user_email = repo.user_email().find(&user, &input.email).await?;
let (added, user_email) = if let Some(user_email) = existing_user_email {
Expand Down
2 changes: 2 additions & 0 deletions crates/graphql/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
// limitations under the License.

use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};

use crate::Requester;

#[async_trait::async_trait]
pub trait State {
async fn repository(&self) -> Result<BoxRepository, RepositoryError>;
async fn policy(&self) -> Result<Policy, mas_policy::InstantiateError>;
fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error>;
fn clock(&self) -> BoxClock;
fn rng(&self) -> BoxRng;
Expand Down
45 changes: 30 additions & 15 deletions crates/handlers/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
response::IntoResponse,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory};
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock};
use mas_storage_pg::PgRepository;
Expand All @@ -33,7 +33,6 @@ use opentelemetry::{
};
use rand::SeedableRng;
use sqlx::PgPool;
use thiserror::Error;

use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver};

Expand Down Expand Up @@ -176,12 +175,6 @@ impl FromRef<AppState> for MatrixHomeserver {
}
}

impl FromRef<AppState> for Arc<PolicyFactory> {
fn from_ref(input: &AppState) -> Self {
input.policy_factory.clone()
}
}

impl FromRef<AppState> for HttpClientFactory {
fn from_ref(input: &AppState) -> Self {
input.http_client_factory.clone()
Expand Down Expand Up @@ -236,19 +229,41 @@ impl FromRequestParts<AppState> for BoxRng {
}
}

#[derive(Debug, Error)]
#[error(transparent)]
pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError);
/// A simple wrapper around an error that implements [`IntoResponse`].
pub struct ErrorWrapper<T>(T);

impl IntoResponse for RepositoryError {
fn into_response(self) -> axum::response::Response {
impl<T> From<T> for ErrorWrapper<T> {
fn from(input: T) -> Self {
Self(input)
}
}

impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}

#[async_trait]
impl FromRequestParts<AppState> for Policy {
type Rejection = ErrorWrapper<mas_policy::InstantiateError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let policy = state.policy_factory.instantiate().await?;
Ok(policy)
}
}

#[async_trait]
impl FromRequestParts<AppState> for BoxRepository {
type Rejection = RepositoryError;
type Rejection = ErrorWrapper<mas_storage_pg::DatabaseError>;

async fn from_request_parts(
_parts: &mut axum::http::request::Parts,
Expand Down
8 changes: 8 additions & 0 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt};
use mas_graphql::{Requester, Schema};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock,
};
Expand All @@ -48,6 +49,7 @@ mod tests;
struct GraphQLState {
pool: PgPool,
homeserver_connection: Arc<dyn HomeserverConnection<Error = anyhow::Error>>,
policy_factory: Arc<PolicyFactory>,
}

#[async_trait]
Expand All @@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState {
Ok(repo.map_err(RepositoryError::from_error).boxed())
}

async fn policy(&self) -> Result<Policy, InstantiateError> {
self.policy_factory.instantiate().await
}

fn homeserver_connection(&self) -> &dyn HomeserverConnection<Error = anyhow::Error> {
self.homeserver_connection.as_ref()
}
Expand All @@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState {
#[must_use]
pub fn schema(
pool: &PgPool,
policy_factory: &Arc<PolicyFactory>,
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection),
};
let state: mas_graphql::BoxState = Box::new(state);
Expand Down
8 changes: 4 additions & 4 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
clippy::let_with_type_underscore,
)]

use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{convert::Infallible, time::Duration};

use axum::{
body::{Bytes, HttpBody},
Expand All @@ -50,7 +50,7 @@ use hyper::{
use mas_axum_utils::{cookies::CookieJar, FancyError};
use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
use mas_policy::PolicyFactory;
use mas_policy::Policy;
use mas_router::{Route, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{ErrorContext, NotFoundContext, Templates};
Expand Down Expand Up @@ -166,12 +166,12 @@ where
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
// All those routes are API-like, with a common CORS layer
Router::new()
Expand Down Expand Up @@ -267,7 +267,6 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
BoxRepository: FromRequestParts<S>,
CookieJar: FromRequestParts<S>,
Encrypter: FromRef<S>,
Expand All @@ -278,6 +277,7 @@ where
MetadataCache: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
Router::new()
// XXX: hard-coded redirect from /account to /account/
Expand Down
13 changes: 4 additions & 9 deletions crates/handlers/src/oauth2/authorization/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use axum::{
extract::{Path, State},
response::{Html, IntoResponse, Response},
Expand All @@ -22,7 +20,7 @@ use hyper::StatusCode;
use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt};
use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device};
use mas_keystore::Keystore;
use mas_policy::{EvaluationResult, PolicyFactory};
use mas_policy::{EvaluationResult, Policy};
use mas_router::{PostAuthAction, Route, UrlBuilder};
use mas_storage::{
oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository},
Expand Down Expand Up @@ -76,7 +74,6 @@ impl IntoResponse for RouteError {
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::InstantiateError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(super::callback::CallbackDestinationError);
Expand All @@ -90,10 +87,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError);
pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
State(policy_factory): State<Arc<PolicyFactory>>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(key_store): State<Keystore>,
policy: Policy,
mut repo: BoxRepository,
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
Expand Down Expand Up @@ -128,7 +125,7 @@ pub(crate) async fn get(
&clock,
repo,
key_store,
&policy_factory,
policy,
url_builder,
grant,
&client,
Expand Down Expand Up @@ -187,7 +184,6 @@ pub enum GrantCompletionError {
impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError);
impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError);
impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError);
impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError);

Expand All @@ -196,7 +192,7 @@ pub(crate) async fn complete(
clock: &impl Clock,
mut repo: BoxRepository,
key_store: Keystore,
policy_factory: &PolicyFactory,
mut policy: Policy,
url_builder: UrlBuilder,
grant: AuthorizationGrant,
client: &Client,
Expand All @@ -220,7 +216,6 @@ pub(crate) async fn complete(
};

// Run through the policy
let mut policy = policy_factory.instantiate().await?;
let res = policy
.evaluate_authorization_grant(&grant, client, &browser_session.user)
.await?;
Expand Down
Loading

0 comments on commit cda799e

Please sign in to comment.