From 05aa5b6a5c9d3028a4d42369bf73d058f412150d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 30 Aug 2023 15:01:37 +0200 Subject: [PATCH 1/3] policies: split the email & password policies and add jsonschema validation of the input --- Cargo.lock | 1 + crates/cli/src/util.rs | 12 +- crates/config/src/sections/policy.rs | 18 ++ .../src/oauth2/authorization/complete.rs | 4 +- .../handlers/src/oauth2/authorization/mod.rs | 2 +- crates/handlers/src/oauth2/consent.rs | 2 +- crates/handlers/src/oauth2/registration.rs | 2 +- crates/handlers/src/test_utils.rs | 17 +- crates/policy/Cargo.toml | 8 +- crates/policy/src/bin/schema.rs | 55 +++++ crates/policy/src/lib.rs | 199 ++++++++++-------- crates/policy/src/model.rs | 96 +++++++++ misc/update.sh | 2 + policies/Makefile | 15 +- policies/authorization_grant.rego | 3 + policies/client_registration.rego | 3 + policies/email.rego | 35 +++ policies/password.rego | 30 +++ policies/register.rego | 58 ++--- policies/register_test.rego | 53 +++-- .../schema/authorization_grant_input.json | 24 +++ .../schema/client_registration_input.json | 14 ++ policies/schema/email_input.json | 13 ++ policies/schema/password_input.json | 13 ++ policies/schema/register_input.json | 32 +++ 25 files changed, 547 insertions(+), 164 deletions(-) create mode 100644 crates/policy/src/bin/schema.rs create mode 100644 crates/policy/src/model.rs create mode 100644 policies/email.rego create mode 100644 policies/password.rego create mode 100644 policies/schema/authorization_grant_input.json create mode 100644 policies/schema/client_registration_input.json create mode 100644 policies/schema/email_input.json create mode 100644 policies/schema/password_input.json create mode 100644 policies/schema/register_input.json diff --git a/Cargo.lock b/Cargo.lock index 347eb3bed..a6533dce3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3047,6 +3047,7 @@ dependencies = [ "mas-data-model", "oauth2-types", "opa-wasm", + "schemars", "serde", "serde_json", "thiserror", diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 5eb934ba1..2ebe2efe9 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -97,12 +97,18 @@ pub async fn policy_factory_from_config( .await .context("failed to open OPA WASM policy file")?; + let entrypoints = mas_policy::Entrypoints { + register: config.register_entrypoint.clone(), + client_registration: config.client_registration_entrypoint.clone(), + authorization_grant: config.authorization_grant_entrypoint.clone(), + email: config.email_entrypoint.clone(), + password: config.password_entrypoint.clone(), + }; + PolicyFactory::load( policy_file, config.data.clone().unwrap_or_default(), - config.register_entrypoint.clone(), - config.client_registration_entrypoint.clone(), - config.authorization_grant_entrypoint.clone(), + entrypoints, ) .await .context("failed to load the policy") diff --git a/crates/config/src/sections/policy.rs b/crates/config/src/sections/policy.rs index 9317cfb92..b3e14954c 100644 --- a/crates/config/src/sections/policy.rs +++ b/crates/config/src/sections/policy.rs @@ -48,6 +48,14 @@ fn default_authorization_grant_endpoint() -> String { "authorization_grant/violation".to_owned() } +fn default_password_endpoint() -> String { + "password/violation".to_owned() +} + +fn default_email_endpoint() -> String { + "email/violation".to_owned() +} + /// Application secrets #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -69,6 +77,14 @@ pub struct PolicyConfig { #[serde(default = "default_authorization_grant_endpoint")] pub authorization_grant_entrypoint: String, + /// Entrypoint to use when changing password + #[serde(default = "default_password_endpoint")] + pub password_entrypoint: String, + + /// Entrypoint to use when adding an email address + #[serde(default = "default_email_endpoint")] + pub email_entrypoint: String, + /// Arbitrary data to pass to the policy #[serde(default)] pub data: Option, @@ -81,6 +97,8 @@ impl Default for PolicyConfig { client_registration_entrypoint: default_client_registration_endpoint(), register_entrypoint: default_register_endpoint(), authorization_grant_entrypoint: default_authorization_grant_endpoint(), + password_entrypoint: default_password_endpoint(), + email_entrypoint: default_email_endpoint(), data: None, } } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 687c1c648..6c5b7b7f4 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -76,7 +76,7 @@ impl IntoResponse for RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); @@ -187,7 +187,7 @@ pub enum GrantCompletionError { impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); -impl_from_error_for_route!(GrantCompletionError: mas_policy::InstanciateError); +impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index ecc3af3db..8fec59fe7 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -94,7 +94,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); #[derive(Deserialize)] diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 916c36818..85acb82fc 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -61,7 +61,7 @@ pub enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl IntoResponse for RouteError { diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index e7859fb11..d2f2c5e36 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -49,7 +49,7 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstanciateError); +impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_keystore::aead::Error); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index c06a8ba51..ca3a581fc 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -117,14 +117,15 @@ impl TestState { let file = tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?; - let policy_factory = PolicyFactory::load( - file, - serde_json::json!({}), - "register/violation".to_owned(), - "client_registration/violation".to_owned(), - "authorization_grant/violation".to_owned(), - ) - .await?; + let entrypoints = mas_policy::Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + password: "password/violation".to_owned(), + }; + + let policy_factory = PolicyFactory::load(file, serde_json::json!({}), entrypoints).await?; let homeserver_connection = MockHomeserverConnection::new("example.com"); diff --git a/crates/policy/Cargo.toml b/crates/policy/Cargo.toml index 3c779e3c9..5f25991f4 100644 --- a/crates/policy/Cargo.toml +++ b/crates/policy/Cargo.toml @@ -10,8 +10,9 @@ anyhow.workspace = true opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" } serde.workspace = true serde_json.workspace = true +schemars = {version = "0.8.1", optional = true } thiserror.workspace = true -tokio = { version = "1.32.0", features = ["io-util"] } +tokio = { version = "1.32.0", features = ["io-util", "rt"] } tracing.workspace = true wasmtime = { version = "12.0.1", default-features = false, features = ["async", "cranelift"] } @@ -23,3 +24,8 @@ tokio = { version = "1.32.0", features = ["fs", "rt", "macros"] } [features] cache = ["wasmtime/cache"] +jsonschema = ["dep:schemars"] + +[[bin]] +name = "schema" +required-features = ["jsonschema"] \ No newline at end of file diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs new file mode 100644 index 000000000..7742a28aa --- /dev/null +++ b/crates/policy/src/bin/schema.rs @@ -0,0 +1,55 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::{Path, PathBuf}; + +use mas_policy::model::{ + AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput, +}; +use schemars::{gen::SchemaSettings, JsonSchema}; + +fn write_schema(out_dir: Option<&Path>, file: &str) { + let mut writer: Box = match out_dir { + Some(out_dir) => { + let path = out_dir.join(file); + eprintln!("Writing to {path:?}"); + let file = std::fs::File::create(path).expect("Failed to create file"); + Box::new(std::io::BufWriter::new(file)) + } + None => { + eprintln!("--- {file} ---"); + Box::new(std::io::stdout()) + } + }; + + let settings = SchemaSettings::draft07().with(|s| { + s.option_nullable = false; + s.option_add_null_type = false; + }); + let generator = settings.into_generator(); + let schema = generator.into_root_schema_for::(); + serde_json::to_writer_pretty(&mut writer, &schema).expect("Failed to serialize schema"); + writer.flush().expect("Failed to flush writer"); +} + +fn main() { + let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok(); + let output_root = output_root.as_deref(); + + write_schema::(output_root, "register_input.json"); + write_schema::(output_root, "client_registration_input.json"); + write_schema::(output_root, "authorization_grant_input.json"); + write_schema::(output_root, "email_input.json"); + write_schema::(output_root, "password_input.json"); +} diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 665afe5f7..a9d44c48b 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. +// Copyright 2022-2023 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,14 +17,20 @@ #![warn(clippy::pedantic)] #![allow(clippy::missing_errors_doc)] +pub mod model; + use mas_data_model::{AuthorizationGrant, Client, User}; use oauth2_types::registration::VerifiedClientMetadata; use opa_wasm::Runtime; -use serde::Deserialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; use wasmtime::{Config, Engine, Module, Store}; +use self::model::{ + AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput, +}; +pub use self::model::{EvaluationResult, Violation}; + #[derive(Debug, Error)] pub enum LoadError { #[error("failed to read module")] @@ -40,7 +46,7 @@ pub enum LoadError { Compilation(#[source] anyhow::Error), #[error("failed to instantiate a test instance")] - Instantiate(#[source] InstanciateError), + Instantiate(#[source] InstantiateError), #[cfg(feature = "cache")] #[error("could not load wasmtime cache configuration")] @@ -48,7 +54,7 @@ pub enum LoadError { } #[derive(Debug, Error)] -pub enum InstanciateError { +pub enum InstantiateError { #[error("failed to create WASM runtime")] Runtime(#[source] anyhow::Error), @@ -59,13 +65,33 @@ pub enum InstanciateError { LoadData(#[source] anyhow::Error), } +/// Holds the entrypoint of each policy +#[derive(Debug, Clone)] +pub struct Entrypoints { + pub register: String, + pub client_registration: String, + pub authorization_grant: String, + pub email: String, + pub password: String, +} + +impl Entrypoints { + fn all(&self) -> [&str; 5] { + [ + self.register.as_str(), + self.client_registration.as_str(), + self.authorization_grant.as_str(), + self.email.as_str(), + self.password.as_str(), + ] + } +} + pub struct PolicyFactory { engine: Engine, module: Module, data: serde_json::Value, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, } impl PolicyFactory { @@ -73,9 +99,7 @@ impl PolicyFactory { pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, data: serde_json::Value, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, ) -> Result { let mut config = Config::default(); config.async_support(true); @@ -103,9 +127,7 @@ impl PolicyFactory { engine, module, data, - register_entrypoint, - client_registration_entrypoint, - authorization_grant_endpoint, + entrypoints, }; // Try to instantiate @@ -118,22 +140,18 @@ impl PolicyFactory { } #[tracing::instrument(name = "policy.instantiate", skip_all, err)] - pub async fn instantiate(&self) -> Result { + pub async fn instantiate(&self) -> Result { let mut store = Store::new(&self.engine, ()); let runtime = Runtime::new(&mut store, &self.module) .await - .map_err(InstanciateError::Runtime)?; + .map_err(InstantiateError::Runtime)?; // Check that we have the required entrypoints - let entrypoints = runtime.entrypoints(); - - for e in [ - self.register_entrypoint.as_str(), - self.client_registration_entrypoint.as_str(), - self.authorization_grant_endpoint.as_str(), - ] { - if !entrypoints.contains(e) { - return Err(InstanciateError::MissingEntrypoint { + let policy_entrypoints = runtime.entrypoints(); + + for e in self.entrypoints.all() { + if !policy_entrypoints.contains(e) { + return Err(InstantiateError::MissingEntrypoint { entrypoint: e.to_owned(), }); } @@ -142,43 +160,20 @@ impl PolicyFactory { let instance = runtime .with_data(&mut store, &self.data) .await - .map_err(InstanciateError::LoadData)?; + .map_err(InstantiateError::LoadData)?; Ok(Policy { store, instance, - register_entrypoint: self.register_entrypoint.clone(), - client_registration_entrypoint: self.client_registration_entrypoint.clone(), - authorization_grant_endpoint: self.authorization_grant_endpoint.clone(), + entrypoints: self.entrypoints.clone(), }) } } -#[derive(Deserialize, Debug)] -pub struct Violation { - pub msg: String, - pub field: Option, -} - -#[derive(Deserialize, Debug)] -pub struct EvaluationResult { - #[serde(rename = "result")] - pub violations: Vec, -} - -impl EvaluationResult { - #[must_use] - pub fn valid(&self) -> bool { - self.violations.is_empty() - } -} - pub struct Policy { store: Store<()>, instance: opa_wasm::Policy, - register_entrypoint: String, - client_registration_entrypoint: String, - authorization_grant_endpoint: String, + entrypoints: Entrypoints, } #[derive(Debug, Error)] @@ -189,11 +184,50 @@ pub enum EvaluationError { } impl Policy { + #[tracing::instrument( + name = "policy.evaluate_email", + skip_all, + fields( + input.email = email, + ), + err, + )] + pub async fn evaluate_email( + &mut self, + email: &str, + ) -> Result { + let input = EmailInput { email }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.email, &input) + .await?; + + Ok(res) + } + + #[tracing::instrument(name = "policy.evaluate_password", skip_all, err)] + pub async fn evaluate_password( + &mut self, + password: &str, + ) -> Result { + let input = PasswordInput { password }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.password, &input) + .await?; + + Ok(res) + } + #[tracing::instrument( name = "policy.evaluate.register", skip_all, fields( - data.username = username, + input.registration_method = "password", + input.user.username = username, + input.user.email = email, ), err, )] @@ -203,17 +237,15 @@ impl Policy { password: &str, email: &str, ) -> Result { - let input = serde_json::json!({ - "user": { - "username": username, - "password": password, - "email": email - } - }); + let input = RegisterInput::Password { + username, + password, + email, + }; let [res]: [EvaluationResult; 1] = self .instance - .evaluate(&mut self.store, &self.register_entrypoint, &input) + .evaluate(&mut self.store, &self.entrypoints.register, &input) .await?; Ok(res) @@ -224,16 +256,13 @@ impl Policy { &mut self, client_metadata: &VerifiedClientMetadata, ) -> Result { - let client_metadata = serde_json::to_value(client_metadata)?; - let input = serde_json::json!({ - "client_metadata": client_metadata, - }); + let input = ClientRegistrationInput { client_metadata }; let [res]: [EvaluationResult; 1] = self .instance .evaluate( &mut self.store, - &self.client_registration_entrypoint, + &self.entrypoints.client_registration, &input, ) .await?; @@ -245,9 +274,9 @@ impl Policy { name = "policy.evaluate.authorization_grant", skip_all, fields( - data.authorization_grant.id = %authorization_grant.id, - data.client.id = %client.id, - data.user.id = %user.id, + input.authorization_grant.id = %authorization_grant.id, + input.client.id = %client.id, + input.user.id = %user.id, ), err, )] @@ -257,17 +286,19 @@ impl Policy { client: &Client, user: &User, ) -> Result { - let authorization_grant = serde_json::to_value(authorization_grant)?; - let user = serde_json::to_value(user)?; - let input = serde_json::json!({ - "authorization_grant": authorization_grant, - "client": client, - "user": user, - }); + let input = AuthorizationGrantInput { + user, + client, + authorization_grant, + }; let [res]: [EvaluationResult; 1] = self .instance - .evaluate(&mut self.store, &self.authorization_grant_endpoint, &input) + .evaluate( + &mut self.store, + &self.entrypoints.authorization_grant, + &input, + ) .await?; Ok(res) @@ -294,15 +325,15 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let factory = PolicyFactory::load( - file, - data, - "register/violation".to_owned(), - "client_registration/violation".to_owned(), - "authorization_grant/violation".to_owned(), - ) - .await - .unwrap(); + let entrypoints = Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + email: "email/violation".to_owned(), + password: "password/violation".to_owned(), + }; + + let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); let mut policy = factory.instantiate().await.unwrap(); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs new file mode 100644 index 000000000..3cc9ff1f7 --- /dev/null +++ b/crates/policy/src/model.rs @@ -0,0 +1,96 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use mas_data_model::{AuthorizationGrant, Client, User}; +use oauth2_types::registration::VerifiedClientMetadata; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct Violation { + pub msg: String, + pub field: Option, +} + +#[derive(Deserialize, Debug)] +pub struct EvaluationResult { + #[serde(rename = "result")] + pub violations: Vec, +} + +impl EvaluationResult { + #[must_use] + pub fn valid(&self) -> bool { + self.violations.is_empty() + } +} + +#[derive(Serialize, Debug)] +#[serde(tag = "registration_method", rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub enum RegisterInput<'a> { + Password { + username: &'a str, + password: &'a str, + email: &'a str, + }, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct ClientRegistrationInput<'a> { + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub client_metadata: &'a VerifiedClientMetadata, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct AuthorizationGrantInput<'a> { + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub user: &'a User, + + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub client: &'a Client, + + #[cfg_attr( + feature = "jsonschema", + schemars(with = "std::collections::HashMap") + )] + pub authorization_grant: &'a AuthorizationGrant, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct EmailInput<'a> { + pub email: &'a str, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct PasswordInput<'a> { + pub password: &'a str, +} diff --git a/misc/update.sh b/misc/update.sh index 52d7ac36b..91e275df2 100644 --- a/misc/update.sh +++ b/misc/update.sh @@ -6,10 +6,12 @@ export SQLX_OFFLINE=1 BASE_DIR="$(dirname "$0")/.." CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json" GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql" +POLICIES_SCHEMA="${BASE_DIR}/policies/schema/" set -x cargo run -p mas-config > "${CONFIG_SCHEMA}" cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}" +OUT_DIR="${POLICIES_SCHEMA}" cargo run -p mas-policy --features jsonschema cd "${BASE_DIR}/frontend" npm run generate diff --git a/policies/Makefile b/policies/Makefile index d110c21da..2c9ff7c93 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -1,6 +1,13 @@ # Set to 1 to run OPA through Docker DOCKER := 0 -OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0 +OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.55.0-debug + +INPUTS := \ + client_registration.rego \ + register.rego \ + authorization_grant.rego \ + password.rego \ + email.rego ifeq ($(DOCKER), 0) OPA := opa @@ -10,11 +17,13 @@ else OPA_RW := docker run -i -v $(shell pwd):/policies -w /policies --rm $(OPA_DOCKER_IMAGE) endif -policy.wasm: client_registration.rego register.rego authorization_grant.rego +policy.wasm: $(INPUTS) $(OPA_RW) build -t wasm \ -e "client_registration/violation" \ -e "register/violation" \ -e "authorization_grant/violation" \ + -e "password/violation" \ + -e "email/violation" \ $^ tar xzf bundle.tar.gz /policy.wasm $(RM) bundle.tar.gz @@ -26,7 +35,7 @@ fmt: .PHONY: test test: - $(OPA) test -v ./*.rego + $(OPA) test --schema ./schema/ -v ./*.rego .PHONY: coverage coverage: diff --git a/policies/authorization_grant.rego b/policies/authorization_grant.rego index 2fd0c7171..d59c6c572 100644 --- a/policies/authorization_grant.rego +++ b/policies/authorization_grant.rego @@ -1,3 +1,6 @@ +# METADATA +# schemas: +# - input: schema["authorization_grant_input"] package authorization_grant import future.keywords.in diff --git a/policies/client_registration.rego b/policies/client_registration.rego index 7ea671f04..a41375cf3 100644 --- a/policies/client_registration.rego +++ b/policies/client_registration.rego @@ -1,3 +1,6 @@ +# METADATA +# schemas: +# - input: schema["client_registration_input"] package client_registration import future.keywords.in diff --git a/policies/email.rego b/policies/email.rego new file mode 100644 index 000000000..fecad108f --- /dev/null +++ b/policies/email.rego @@ -0,0 +1,35 @@ +# METADATA +# schemas: +# - input: schema["email_input"] +package email + +import future.keywords.in + +default allow := false + +allow { + count(violation) == 0 +} + +# Allow any domains if the data.allowed_domains array is not set +email_domain_allowed { + not data.allowed_domains +} + +# Allow an email only if its domain is in the list of allowed domains +email_domain_allowed { + [_, domain] := split(input.email, "@") + some allowed_domain in data.allowed_domains + glob.match(allowed_domain, ["."], domain) +} + +violation[{"msg": "email domain is not allowed"}] { + not email_domain_allowed +} + +# Deny emails with their domain in the domains banlist +violation[{"msg": "email domain is banned"}] { + [_, domain] := split(input.email, "@") + some banned_domain in data.banned_domains + glob.match(banned_domain, ["."], domain) +} diff --git a/policies/password.rego b/policies/password.rego new file mode 100644 index 000000000..bae1c215a --- /dev/null +++ b/policies/password.rego @@ -0,0 +1,30 @@ +# METADATA +# schemas: +# - input: schema["password_input"] +package password + +default allow := false + +allow { + count(violation) == 0 +} + +violation[{"msg": msg}] { + count(input.password) < data.passwords.min_length + msg := sprintf("needs to be at least %d characters", [data.passwords.min_length]) +} + +violation[{"msg": "requires at least one number"}] { + data.passwords.require_number + not regex.match("[0-9]", input.password) +} + +violation[{"msg": "requires at least one lowercase letter"}] { + data.passwords.require_lowercase + not regex.match("[a-z]", input.password) +} + +violation[{"msg": "requires at least one uppercase letter"}] { + data.passwords.require_uppercase + not regex.match("[A-Z]", input.password) +} diff --git a/policies/register.rego b/policies/register.rego index 391fc37ba..b15e0fdce 100644 --- a/policies/register.rego +++ b/policies/register.rego @@ -1,5 +1,11 @@ +# METADATA +# schemas: +# - input: schema["register_input"] package register +import data.email as email_policy +import data.password as password_policy + import future.keywords.in default allow := false @@ -9,52 +15,24 @@ allow { } violation[{"field": "username", "msg": "username too short"}] { - count(input.user.username) <= 2 + count(input.username) <= 2 } violation[{"field": "username", "msg": "username too long"}] { - count(input.user.username) >= 15 -} - -violation[{"field": "password", "msg": msg}] { - count(input.user.password) < data.passwords.min_length - msg := sprintf("needs to be at least %d characters", [data.passwords.min_length]) -} - -violation[{"field": "password", "msg": "requires at least one number"}] { - data.passwords.require_number - not regex.match("[0-9]", input.user.password) -} - -violation[{"field": "password", "msg": "requires at least one lowercase letter"}] { - data.passwords.require_lowercase - not regex.match("[a-z]", input.user.password) -} - -violation[{"field": "password", "msg": "requires at least one uppercase letter"}] { - data.passwords.require_uppercase - not regex.match("[A-Z]", input.user.password) + count(input.username) >= 15 } -# Allow any domains if the data.allowed_domains array is not set -email_domain_allowed { - not data.allowed_domains -} - -# Allow an email only if its domain is in the list of allowed domains -email_domain_allowed { - [_, domain] := split(input.user.email, "@") - some allowed_domain in data.allowed_domains - glob.match(allowed_domain, ["."], domain) -} +violation[object.union({"field": "password"}, v)] { + # Check if the registration method is password + input.registration_method == "password" -violation[{"field": "email", "msg": "email domain not allowed"}] { - not email_domain_allowed + # Get the violation object from the password policy + some v in password_policy.violation } -# Deny emails with their domain in the domains banlist -violation[{"field": "email", "msg": "email domain not allowed"}] { - [_, domain] := split(input.user.email, "@") - some banned_domain in data.banned_domains - glob.match(banned_domain, ["."], domain) +# Check if the email is valid using the email policy +# and add the email field to the violation object +violation[object.union({"field": "email"}, v)] { + # Get the violation object from the email policy + some v in email_policy.violation } diff --git a/policies/register_test.rego b/policies/register_test.rego index d2b042fee..70acea87a 100644 --- a/policies/register_test.rego +++ b/policies/register_test.rego @@ -1,72 +1,85 @@ package register -mock_user := {"username": "hello", "password": "Hunter2", "email": "hello@staging.element.io"} +mock_registration := { + "registration_method": "password", + "username": "hello", + "password": "Hunter2", + "email": "hello@staging.element.io", +} test_allow_all_domains { - allow with input.user as mock_user + allow with input as mock_registration } test_allowed_domain { - allow with input.user as mock_user + allow with input as mock_registration with data.allowed_domains as ["*.element.io"] } test_not_allowed_domain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.allowed_domains as ["example.com"] } test_banned_domain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.banned_domains as ["*.element.io"] } test_banned_subdomain { - not allow with input.user as mock_user + not allow with input as mock_registration with data.allowed_domains as ["*.element.io"] with data.banned_domains as ["staging.element.io"] } test_short_username { - not allow with input.user as {"username": "a", "email": "hello@element.io"} + not allow with input as {"username": "a", "email": "hello@element.io"} } test_long_username { - not allow with input.user as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} + not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} } test_password_require_number { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_number as true - not allow with input.user as mock_user - with input.user.password as "hunter" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "hunter" with data.passwords.require_number as true } test_password_require_lowercase { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_lowercase as true - not allow with input.user as mock_user - with input.user.password as "HUNTER2" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "HUNTER2" with data.passwords.require_lowercase as true } test_password_require_uppercase { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.require_uppercase as true - not allow with input.user as mock_user - with input.user.password as "hunter2" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "hunter2" with data.passwords.require_uppercase as true } test_password_min_length { - allow with input.user as mock_user + allow with input as mock_registration + with input.registration_method as "password" with data.passwords.min_length as 6 - not allow with input.user as mock_user - with input.user.password as "short" + not allow with input as mock_registration + with input.registration_method as "password" + with input.password as "short" with data.passwords.min_length as 6 } diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json new file mode 100644 index 000000000..a1a49a8d4 --- /dev/null +++ b/policies/schema/authorization_grant_input.json @@ -0,0 +1,24 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "AuthorizationGrantInput", + "type": "object", + "required": [ + "authorization_grant", + "client", + "user" + ], + "properties": { + "authorization_grant": { + "type": "object", + "additionalProperties": true + }, + "client": { + "type": "object", + "additionalProperties": true + }, + "user": { + "type": "object", + "additionalProperties": true + } + } +} \ No newline at end of file diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json new file mode 100644 index 000000000..7261068e5 --- /dev/null +++ b/policies/schema/client_registration_input.json @@ -0,0 +1,14 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "ClientRegistrationInput", + "type": "object", + "required": [ + "client_metadata" + ], + "properties": { + "client_metadata": { + "type": "object", + "additionalProperties": true + } + } +} \ No newline at end of file diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json new file mode 100644 index 000000000..487eb4b92 --- /dev/null +++ b/policies/schema/email_input.json @@ -0,0 +1,13 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "EmailInput", + "type": "object", + "required": [ + "email" + ], + "properties": { + "email": { + "type": "string" + } + } +} \ No newline at end of file diff --git a/policies/schema/password_input.json b/policies/schema/password_input.json new file mode 100644 index 000000000..d85b2862e --- /dev/null +++ b/policies/schema/password_input.json @@ -0,0 +1,13 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "PasswordInput", + "type": "object", + "required": [ + "password" + ], + "properties": { + "password": { + "type": "string" + } + } +} \ No newline at end of file diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json new file mode 100644 index 000000000..d77ce66e0 --- /dev/null +++ b/policies/schema/register_input.json @@ -0,0 +1,32 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "RegisterInput", + "oneOf": [ + { + "type": "object", + "required": [ + "email", + "password", + "registration_method", + "username" + ], + "properties": { + "email": { + "type": "string" + }, + "password": { + "type": "string" + }, + "registration_method": { + "type": "string", + "enum": [ + "password" + ] + }, + "username": { + "type": "string" + } + } + } + ] +} \ No newline at end of file From cda799e5ada7d2c46c17e7d05ec567f6fd2b70cf Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 30 Aug 2023 16:47:57 +0200 Subject: [PATCH 2/3] Make sure we validate passwords & emails by the policy at all stages Also refactors the way we get the policy engines in requests --- Cargo.lock | 1 + crates/axum-utils/src/fancy_error.rs | 6 ++ crates/cli/src/commands/server.rs | 2 +- crates/graphql/Cargo.toml | 1 + crates/graphql/src/mutations/user_email.rs | 28 ++++++++- crates/graphql/src/state.rs | 2 + crates/handlers/src/app_state.rs | 45 ++++++++----- crates/handlers/src/graphql/mod.rs | 8 +++ crates/handlers/src/lib.rs | 8 +-- .../src/oauth2/authorization/complete.rs | 13 ++-- .../handlers/src/oauth2/authorization/mod.rs | 11 ++-- crates/handlers/src/oauth2/consent.rs | 11 +--- crates/handlers/src/oauth2/registration.rs | 8 +-- crates/handlers/src/test_utils.rs | 35 +++++++---- .../handlers/src/views/account/emails/add.rs | 63 ++++++++++++++----- crates/handlers/src/views/account/password.rs | 11 +++- crates/handlers/src/views/register.rs | 7 +-- crates/policy/src/bin/schema.rs | 2 + crates/policy/src/model.rs | 23 +++++++ docs/config.schema.json | 12 ++++ frontend/schema.graphql | 8 +++ .../components/UserProfile/AddEmailForm.tsx | 14 +++++ frontend/src/gql/gql.ts | 6 +- frontend/src/gql/graphql.ts | 6 ++ frontend/src/gql/schema.ts | 14 +++++ .../schema/authorization_grant_input.json | 1 + .../schema/client_registration_input.json | 1 + policies/schema/email_input.json | 1 + policies/schema/password_input.json | 1 + policies/schema/register_input.json | 1 + 30 files changed, 265 insertions(+), 85 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a6533dce3..a8705be13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2763,6 +2763,7 @@ dependencies = [ "lettre", "mas-data-model", "mas-matrix", + "mas-policy", "mas-storage", "oauth2-types", "serde", diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs index 363f423ef..86e0a60f5 100644 --- a/crates/axum-utils/src/fancy_error.rs +++ b/crates/axum-utils/src/fancy_error.rs @@ -23,6 +23,12 @@ pub struct FancyError { context: ErrorContext, } +impl FancyError { + pub fn new(context: ErrorContext) -> Self { + Self { context } + } +} + impl std::fmt::Display for FancyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let code = self.context.code().unwrap_or("Internal error"); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index f293aa10c..b00142528 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -143,7 +143,7 @@ impl Options { // Listen for SIGHUP register_sighup(&templates)?; - let graphql_schema = mas_handlers::graphql_schema(&pool, conn); + let graphql_schema = mas_handlers::graphql_schema(&pool, &policy_factory, conn); let state = { let mut s = AppState { diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index 39087f6c4..88db93e28 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -22,6 +22,7 @@ url.workspace = true oauth2-types = { path = "../oauth2-types" } mas-data-model = { path = "../data-model" } mas-matrix = { path = "../matrix" } +mas-policy = { path = "../policy" } mas-storage = { path = "../storage" } [[bin]] diff --git a/crates/graphql/src/mutations/user_email.rs b/crates/graphql/src/mutations/user_email.rs index 2ea7052cd..ef8c124f8 100644 --- a/crates/graphql/src/mutations/user_email.rs +++ b/crates/graphql/src/mutations/user_email.rs @@ -49,6 +49,8 @@ pub enum AddEmailStatus { Exists, /// The email address is invalid Invalid, + /// The email address is not allowed by the policy + Denied, } /// The payload of the `addEmail` mutation @@ -57,6 +59,9 @@ enum AddEmailPayload { Added(mas_data_model::UserEmail), Exists(mas_data_model::UserEmail), Invalid, + Denied { + violations: Vec, + }, } #[Object(use_type_description)] @@ -67,6 +72,7 @@ impl AddEmailPayload { AddEmailPayload::Added(_) => AddEmailStatus::Added, AddEmailPayload::Exists(_) => AddEmailStatus::Exists, AddEmailPayload::Invalid => AddEmailStatus::Invalid, + AddEmailPayload::Denied { .. } => AddEmailStatus::Denied, } } @@ -76,7 +82,7 @@ impl AddEmailPayload { AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => { Some(UserEmail(email.clone())) } - AddEmailPayload::Invalid => None, + AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => None, } } @@ -87,7 +93,7 @@ impl AddEmailPayload { let user_id = match self { AddEmailPayload::Added(email) | AddEmailPayload::Exists(email) => email.user_id, - AddEmailPayload::Invalid => return Ok(None), + AddEmailPayload::Invalid | AddEmailPayload::Denied { .. } => return Ok(None), }; let user = repo @@ -98,6 +104,16 @@ impl AddEmailPayload { Ok(Some(User(user))) } + + /// The list of policy violations if the email address was denied + async fn violations(&self) -> Option> { + let AddEmailPayload::Denied { violations } = self else { + return None; + }; + + let messages = violations.iter().map(|v| v.msg.clone()).collect(); + Some(messages) + } } /// The input for the `sendVerificationEmail` mutation @@ -382,6 +398,14 @@ impl UserEmailMutations { return Ok(AddEmailPayload::Invalid); } + let mut policy = state.policy().await?; + let res = policy.evaluate_email(&input.email).await?; + if !res.valid() { + return Ok(AddEmailPayload::Denied { + violations: res.violations, + }); + } + // Find an existing email address let existing_user_email = repo.user_email().find(&user, &input.email).await?; let (added, user_email) = if let Some(user_email) = existing_user_email { diff --git a/crates/graphql/src/state.rs b/crates/graphql/src/state.rs index 90b2e637d..441d9a749 100644 --- a/crates/graphql/src/state.rs +++ b/crates/graphql/src/state.rs @@ -13,6 +13,7 @@ // limitations under the License. use mas_matrix::HomeserverConnection; +use mas_policy::Policy; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; use crate::Requester; @@ -20,6 +21,7 @@ use crate::Requester; #[async_trait::async_trait] pub trait State { async fn repository(&self) -> Result; + async fn policy(&self) -> Result; fn homeserver_connection(&self) -> &dyn HomeserverConnection; fn clock(&self) -> BoxClock; fn rng(&self) -> BoxRng; diff --git a/crates/handlers/src/app_state.rs b/crates/handlers/src/app_state.rs index 772c4bf53..cf031d889 100644 --- a/crates/handlers/src/app_state.rs +++ b/crates/handlers/src/app_state.rs @@ -17,12 +17,12 @@ use std::{convert::Infallible, sync::Arc, time::Instant}; use axum::{ async_trait, extract::{FromRef, FromRequestParts}, - response::IntoResponse, + response::{IntoResponse, Response}, }; use hyper::StatusCode; use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_keystore::{Encrypter, Keystore}; -use mas_policy::PolicyFactory; +use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, Repository, SystemClock}; use mas_storage_pg::PgRepository; @@ -33,7 +33,6 @@ use opentelemetry::{ }; use rand::SeedableRng; use sqlx::PgPool; -use thiserror::Error; use crate::{passwords::PasswordManager, upstream_oauth2::cache::MetadataCache, MatrixHomeserver}; @@ -176,12 +175,6 @@ impl FromRef for MatrixHomeserver { } } -impl FromRef for Arc { - fn from_ref(input: &AppState) -> Self { - input.policy_factory.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &AppState) -> Self { input.http_client_factory.clone() @@ -236,19 +229,41 @@ impl FromRequestParts for BoxRng { } } -#[derive(Debug, Error)] -#[error(transparent)] -pub struct RepositoryError(#[from] mas_storage_pg::DatabaseError); +/// A simple wrapper around an error that implements [`IntoResponse`]. +pub struct ErrorWrapper(T); -impl IntoResponse for RepositoryError { - fn into_response(self) -> axum::response::Response { +impl From for ErrorWrapper { + fn from(input: T) -> Self { + Self(input) + } +} + +impl IntoResponse for ErrorWrapper +where + T: std::error::Error, +{ + fn into_response(self) -> Response { + // TODO: make this a bit more user friendly (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() } } +#[async_trait] +impl FromRequestParts for Policy { + type Rejection = ErrorWrapper; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &AppState, + ) -> Result { + let policy = state.policy_factory.instantiate().await?; + Ok(policy) + } +} + #[async_trait] impl FromRequestParts for BoxRepository { - type Rejection = RepositoryError; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 89e97bb82..16fc692cb 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -31,6 +31,7 @@ use hyper::header::CACHE_CONTROL; use mas_axum_utils::{cookies::CookieJar, FancyError, SessionInfo, SessionInfoExt}; use mas_graphql::{Requester, Schema}; use mas_matrix::HomeserverConnection; +use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, Repository, RepositoryError, SystemClock, }; @@ -48,6 +49,7 @@ mod tests; struct GraphQLState { pool: PgPool, homeserver_connection: Arc>, + policy_factory: Arc, } #[async_trait] @@ -60,6 +62,10 @@ impl mas_graphql::State for GraphQLState { Ok(repo.map_err(RepositoryError::from_error).boxed()) } + async fn policy(&self) -> Result { + self.policy_factory.instantiate().await + } + fn homeserver_connection(&self) -> &dyn HomeserverConnection { self.homeserver_connection.as_ref() } @@ -81,10 +87,12 @@ impl mas_graphql::State for GraphQLState { #[must_use] pub fn schema( pool: &PgPool, + policy_factory: &Arc, homeserver_connection: impl HomeserverConnection + 'static, ) -> Schema { let state = GraphQLState { pool: pool.clone(), + policy_factory: Arc::clone(policy_factory), homeserver_connection: Arc::new(homeserver_connection), }; let state: mas_graphql::BoxState = Box::new(state); diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index dda27c577..cc6554ebe 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -30,7 +30,7 @@ clippy::let_with_type_underscore, )] -use std::{convert::Infallible, sync::Arc, time::Duration}; +use std::{convert::Infallible, time::Duration}; use axum::{ body::{Bytes, HttpBody}, @@ -50,7 +50,7 @@ use hyper::{ use mas_axum_utils::{cookies::CookieJar, FancyError}; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; use mas_templates::{ErrorContext, NotFoundContext, Templates}; @@ -166,12 +166,12 @@ where S: Clone + Send + Sync + 'static, Keystore: FromRef, UrlBuilder: FromRef, - Arc: FromRef, BoxRepository: FromRequestParts, Encrypter: FromRef, HttpClientFactory: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { // All those routes are API-like, with a common CORS layer Router::new() @@ -267,7 +267,6 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, - Arc: FromRef, BoxRepository: FromRequestParts, CookieJar: FromRequestParts, Encrypter: FromRef, @@ -278,6 +277,7 @@ where MetadataCache: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { Router::new() // XXX: hard-coded redirect from /account to /account/ diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index 6c5b7b7f4..48b1800ce 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Path, State}, response::{Html, IntoResponse, Response}, @@ -22,7 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device}; use mas_keystore::Keystore; -use mas_policy::{EvaluationResult, PolicyFactory}; +use mas_policy::{EvaluationResult, Policy}; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository, OAuth2SessionRepository}, @@ -76,7 +74,6 @@ impl IntoResponse for RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(super::callback::CallbackDestinationError); @@ -90,10 +87,10 @@ impl_from_error_for_route!(super::callback::CallbackDestinationError); pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, State(url_builder): State, State(key_store): State, + policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -128,7 +125,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, @@ -187,7 +184,6 @@ pub enum GrantCompletionError { impl_from_error_for_route!(GrantCompletionError: mas_storage::RepositoryError); impl_from_error_for_route!(GrantCompletionError: super::callback::IntoCallbackDestinationError); impl_from_error_for_route!(GrantCompletionError: mas_policy::LoadError); -impl_from_error_for_route!(GrantCompletionError: mas_policy::InstantiateError); impl_from_error_for_route!(GrantCompletionError: mas_policy::EvaluationError); impl_from_error_for_route!(GrantCompletionError: super::super::IdTokenSignatureError); @@ -196,7 +192,7 @@ pub(crate) async fn complete( clock: &impl Clock, mut repo: BoxRepository, key_store: Keystore, - policy_factory: &PolicyFactory, + mut policy: Policy, url_builder: UrlBuilder, grant: AuthorizationGrant, client: &Client, @@ -220,7 +216,6 @@ pub(crate) async fn complete( }; // Run through the policy - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, client, &browser_session.user) .await?; diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index 8fec59fe7..1cef1ac45 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, @@ -22,7 +20,7 @@ use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, SessionInfoExt}; use mas_data_model::{AuthorizationCode, Pkce}; use mas_keystore::Keystore; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{PostAuthAction, Route, UrlBuilder}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, @@ -94,7 +92,6 @@ impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(self::callback::CallbackDestinationError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); #[derive(Deserialize)] @@ -140,10 +137,10 @@ fn resolve_response_mode( pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, State(key_store): State, State(url_builder): State, + policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Form(params): Form, @@ -346,7 +343,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, @@ -393,7 +390,7 @@ pub(crate) async fn get( &clock, repo, key_store, - &policy_factory, + policy, url_builder, grant, &client, diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index 85acb82fc..c448923a1 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Response}, @@ -25,7 +23,7 @@ use mas_axum_utils::{ SessionInfoExt, }; use mas_data_model::{AuthorizationGrantStage, Device}; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::{PostAuthAction, Route}; use mas_storage::{ oauth2::{OAuth2AuthorizationGrantRepository, OAuth2ClientRepository}, @@ -61,7 +59,6 @@ pub enum RouteError { impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl IntoResponse for RouteError { @@ -80,8 +77,8 @@ impl IntoResponse for RouteError { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -109,7 +106,6 @@ pub(crate) async fn get( if let Some(session) = maybe_session { let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, &client, &session.user) .await?; @@ -146,7 +142,7 @@ pub(crate) async fn get( pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, - State(policy_factory): State>, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -176,7 +172,6 @@ pub(crate) async fn post( .await? .ok_or(RouteError::NoSuchClient)?; - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_authorization_grant(&grant, &client, &session.user) .await?; diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index d2f2c5e36..99cefc7a8 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -12,13 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use axum::{extract::State, response::IntoResponse, Json}; use hyper::StatusCode; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_keystore::Encrypter; -use mas_policy::{PolicyFactory, Violation}; +use mas_policy::{Policy, Violation}; use mas_storage::{oauth2::OAuth2ClientRepository, BoxClock, BoxRepository, BoxRng}; use oauth2_types::{ errors::{ClientError, ClientErrorCode}, @@ -49,7 +47,6 @@ pub(crate) enum RouteError { impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_policy::LoadError); -impl_from_error_for_route!(mas_policy::InstantiateError); impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_keystore::aead::Error); @@ -136,7 +133,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, - State(policy_factory): State>, + mut policy: Policy, State(encrypter): State, body: Result, axum::extract::rejection::JsonRejection>, ) -> Result { @@ -148,7 +145,6 @@ pub(crate) async fn post( // Validate the body let metadata = body.validate()?; - let mut policy = policy_factory.instantiate().await?; let res = policy.evaluate_client_registration(&metadata).await?; if !res.valid() { return Err(RouteError::PolicyDenied(res.violations)); diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index ca3a581fc..6ed765e19 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -33,10 +33,10 @@ use hyper::{ use mas_axum_utils::{cookies::CookieManager, http_client_factory::HttpClientFactory}; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; -use mas_policy::PolicyFactory; +use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository}; -use mas_storage_pg::PgRepository; +use mas_storage_pg::{DatabaseError, PgRepository}; use mas_templates::Templates; use rand::SeedableRng; use rand_chacha::ChaChaRng; @@ -46,7 +46,7 @@ use tower::{Layer, Service, ServiceExt}; use url::Url; use crate::{ - app_state::RepositoryError, + app_state::ErrorWrapper, passwords::{Hasher, PasswordManager}, upstream_oauth2::cache::MetadataCache, MatrixHomeserver, @@ -138,6 +138,7 @@ impl TestState { let graphql_state = TestGraphQLState { pool: pool.clone(), + policy_factory: Arc::clone(&policy_factory), homeserver_connection, rng: Arc::clone(&rng), clock: Arc::clone(&clock), @@ -202,7 +203,7 @@ impl TestState { Response::from_parts(parts, body) } - pub async fn repository(&self) -> Result { + pub async fn repository(&self) -> Result { let repo = PgRepository::from_pool(&self.pool).await?; Ok(repo .map_err(mas_storage::RepositoryError::from_error) @@ -243,6 +244,7 @@ impl TestState { struct TestGraphQLState { pool: PgPool, homeserver_connection: MockHomeserverConnection, + policy_factory: Arc, clock: Arc, rng: Arc>, } @@ -259,6 +261,10 @@ impl mas_graphql::State for TestGraphQLState { .boxed()) } + async fn policy(&self) -> Result { + self.policy_factory.instantiate().await + } + fn homeserver_connection(&self) -> &dyn HomeserverConnection { &self.homeserver_connection } @@ -316,12 +322,6 @@ impl FromRef for MatrixHomeserver { } } -impl FromRef for Arc { - fn from_ref(input: &TestState) -> Self { - input.policy_factory.clone() - } -} - impl FromRef for HttpClientFactory { fn from_ref(input: &TestState) -> Self { input.http_client_factory.clone() @@ -374,7 +374,7 @@ impl FromRequestParts for BoxRng { #[async_trait] impl FromRequestParts for BoxRepository { - type Rejection = RepositoryError; + type Rejection = ErrorWrapper; async fn from_request_parts( _parts: &mut axum::http::request::Parts, @@ -387,6 +387,19 @@ impl FromRequestParts for BoxRepository { } } +#[async_trait] +impl FromRequestParts for Policy { + type Rejection = ErrorWrapper; + + async fn from_request_parts( + _parts: &mut axum::http::request::Parts, + state: &TestState, + ) -> Result { + let policy = state.policy_factory.instantiate().await?; + Ok(policy) + } +} + pub(crate) trait RequestBuilderExt { /// Builds the request with the given JSON value as body. fn json(self, body: T) -> hyper::Request; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index 49036010b..e4577d748 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -21,13 +21,14 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, FancyError, SessionInfoExt, }; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ job::{JobRepositoryExt, VerifyEmailJob}, user::UserEmailRepository, BoxClock, BoxRepository, BoxRng, }; -use mas_templates::{EmailAddContext, TemplateContext, Templates}; +use mas_templates::{EmailAddContext, ErrorContext, TemplateContext, Templates}; use serde::Deserialize; use crate::views::shared::OptionalPostAuthAction; @@ -69,6 +70,7 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, mut repo: BoxRepository, + mut policy: Policy, cookie_jar: CookieJar, Query(query): Query, Form(form): Form>, @@ -83,23 +85,56 @@ pub(crate) async fn post( return Ok((cookie_jar, login.go()).into_response()); }; - let user_email = repo - .user_email() - .add(&mut rng, &clock, &session.user, form.email) - .await?; - - let next = mas_router::AccountVerifyEmail::new(user_email.id); - let next = if let Some(action) = query.post_auth_action { - next.and_then(action) + // XXX: we really should show human readable errors on the form here + + // Validate the email address + if form.email.parse::().is_err() { + return Err(anyhow::anyhow!("Invalid email address").into()); + } + + // Run the email policy + let res = policy.evaluate_email(&form.email).await?; + if !res.valid() { + return Err(FancyError::new( + ErrorContext::new() + .with_description(format!("Email address {:?} denied by policy", form.email)) + .with_details(format!("{res}")), + )); + } + + // Find an existing email address + let existing_user_email = repo.user_email().find(&session.user, &form.email).await?; + let user_email = if let Some(user_email) = existing_user_email { + user_email } else { - next + let user_email = repo + .user_email() + .add(&mut rng, &clock, &session.user, form.email) + .await?; + + user_email }; - repo.job() - .schedule_job(VerifyEmailJob::new(&user_email)) - .await?; + // If the email was not confirmed, send a confirmation email & redirect to the + // verify page + let next = if user_email.confirmed_at.is_none() { + repo.job() + .schedule_job(VerifyEmailJob::new(&user_email)) + .await?; + + let next = mas_router::AccountVerifyEmail::new(user_email.id); + let next = if let Some(action) = query.post_auth_action { + next.and_then(action) + } else { + next + }; + + next.go() + } else { + query.go_next_or_default(&mas_router::Account) + }; repo.save().await?; - Ok((cookie_jar, next.go()).into_response()) + Ok((cookie_jar, next).into_response()) } diff --git a/crates/handlers/src/views/account/password.rs b/crates/handlers/src/views/account/password.rs index 4b2544a10..f21b5743f 100644 --- a/crates/handlers/src/views/account/password.rs +++ b/crates/handlers/src/views/account/password.rs @@ -24,6 +24,7 @@ use mas_axum_utils::{ FancyError, SessionInfoExt, }; use mas_data_model::BrowserSession; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ user::{BrowserSessionRepository, UserPasswordRepository}, @@ -93,6 +94,7 @@ pub(crate) async fn post( clock: BoxClock, State(password_manager): State, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, cookie_jar: CookieJar, Form(form): Form>, @@ -119,6 +121,13 @@ pub(crate) async fn post( .await? .context("user has no password")?; + let res = policy.evaluate_password(&form.new_password).await?; + + // TODO: display nice form errors + if !res.valid() { + return Err(anyhow::anyhow!("Password policy violation: {res}").into()); + } + let password = Zeroizing::new(form.current_password.into_bytes()); let new_password = Zeroizing::new(form.new_password.into_bytes()); let new_password_confirm = Zeroizing::new(form.new_password_confirm.into_bytes()); @@ -133,7 +142,7 @@ pub(crate) async fn post( // TODO: display nice form errors if new_password != new_password_confirm { - return Err(anyhow::anyhow!("password mismatch").into()); + return Err(anyhow::anyhow!("Password mismatch").into()); } let (version, hashed_password) = password_manager.hash(&mut rng, new_password).await?; diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 3e32bad10..070bb954f 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{str::FromStr, sync::Arc}; +use std::str::FromStr; use axum::{ extract::{Form, Query, State}, @@ -27,7 +27,7 @@ use mas_axum_utils::{ csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, SessionInfoExt, }; -use mas_policy::PolicyFactory; +use mas_policy::Policy; use mas_router::Route; use mas_storage::{ job::{JobRepositoryExt, ProvisionUserJob, VerifyEmailJob}, @@ -101,8 +101,8 @@ pub(crate) async fn post( mut rng: BoxRng, clock: BoxClock, State(password_manager): State, - State(policy_factory): State>, State(templates): State, + mut policy: Policy, mut repo: BoxRepository, Query(query): Query, cookie_jar: CookieJar, @@ -148,7 +148,6 @@ pub(crate) async fn post( state.add_error_on_field(RegisterFormField::PasswordConfirm, FieldError::Unspecified); } - let mut policy = policy_factory.instantiate().await?; let res = policy .evaluate_register(&form.username, &form.password, &form.email) .await?; diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs index 7742a28aa..53547db68 100644 --- a/crates/policy/src/bin/schema.rs +++ b/crates/policy/src/bin/schema.rs @@ -43,6 +43,8 @@ fn write_schema(out_dir: Option<&Path>, file: &str) { writer.flush().expect("Failed to flush writer"); } +/// Write the input schemas to the output directory. +/// They are then used in rego files to type check the input. fn main() { let output_root = std::env::var("OUT_DIR").map(PathBuf::from).ok(); let output_root = output_root.as_deref(); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 3cc9ff1f7..65c4a0058 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -16,6 +16,7 @@ use mas_data_model::{AuthorizationGrant, Client, User}; use oauth2_types::registration::VerifiedClientMetadata; use serde::{Deserialize, Serialize}; +/// A single violation of a policy. #[derive(Deserialize, Debug)] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub struct Violation { @@ -23,19 +24,37 @@ pub struct Violation { pub field: Option, } +/// The result of a policy evaluation. #[derive(Deserialize, Debug)] pub struct EvaluationResult { #[serde(rename = "result")] pub violations: Vec, } +impl std::fmt::Display for EvaluationResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for violation in &self.violations { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{}", violation.msg)?; + } + Ok(()) + } +} + impl EvaluationResult { + /// Returns true if the policy evaluation was successful. #[must_use] pub fn valid(&self) -> bool { self.violations.is_empty() } } +/// Input for the user registration policy. #[derive(Serialize, Debug)] #[serde(tag = "registration_method", rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -47,6 +66,7 @@ pub enum RegisterInput<'a> { }, } +/// Input for the client registration policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -58,6 +78,7 @@ pub struct ClientRegistrationInput<'a> { pub client_metadata: &'a VerifiedClientMetadata, } +/// Input for the authorization grant policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -81,6 +102,7 @@ pub struct AuthorizationGrantInput<'a> { pub authorization_grant: &'a AuthorizationGrant, } +/// Input for the email add policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -88,6 +110,7 @@ pub struct EmailInput<'a> { pub email: &'a str, } +/// Input for the password set policy. #[derive(Serialize, Debug)] #[serde(rename_all = "snake_case")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] diff --git a/docs/config.schema.json b/docs/config.schema.json index 835163db2..80663733a 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -147,6 +147,8 @@ "authorization_grant_entrypoint": "authorization_grant/violation", "client_registration_entrypoint": "client_registration/violation", "data": null, + "email_entrypoint": "email/violation", + "password_entrypoint": "password/violation", "register_entrypoint": "register/violation", "wasm_module": "./policies/policy.wasm" }, @@ -1349,6 +1351,16 @@ "description": "Arbitrary data to pass to the policy", "default": null }, + "email_entrypoint": { + "description": "Entrypoint to use when adding an email address", + "default": "email/violation", + "type": "string" + }, + "password_entrypoint": { + "description": "Entrypoint to use when changing password", + "default": "password/violation", + "type": "string" + }, "register_entrypoint": { "description": "Entrypoint to use when evaluating user registrations", "default": "register/violation", diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 821534e35..5a89768e1 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -28,6 +28,10 @@ type AddEmailPayload { The user to whom the email address was added """ user: User + """ + The list of policy violations if the email address was denied + """ + violations: [String!] } """ @@ -46,6 +50,10 @@ enum AddEmailStatus { The email address is invalid """ INVALID + """ + The email address is not allowed by the policy + """ + DENIED } type Anonymous implements Node { diff --git a/frontend/src/components/UserProfile/AddEmailForm.tsx b/frontend/src/components/UserProfile/AddEmailForm.tsx index dabad3265..315220fd8 100644 --- a/frontend/src/components/UserProfile/AddEmailForm.tsx +++ b/frontend/src/components/UserProfile/AddEmailForm.tsx @@ -30,6 +30,7 @@ const ADD_EMAIL_MUTATION = graphql(/* GraphQL */ ` mutation AddEmail($userId: ID!, $email: String!) { addEmail(input: { userId: $userId, email: $email }) { status + violations email { id ...UserEmail_email @@ -79,6 +80,8 @@ const AddEmailForm: React.FC<{ const status = addEmailResult.data?.addEmail.status ?? null; const emailExists = status === "EXISTS"; const emailInvalid = status === "INVALID"; + const emailDenied = status === "DENIED"; + const violations = addEmailResult.data?.addEmail.violations ?? []; return ( <> @@ -95,6 +98,17 @@ const AddEmailForm: React.FC<{ )} + {emailDenied && ( + + The entered email is not allowed by the server policy. +
    + {violations.map((violation, index) => ( +
  • • {violation}
  • + ))} +
+
+ )} + diff --git a/frontend/src/gql/gql.ts b/frontend/src/gql/gql.ts index 009714c1a..199583a95 100644 --- a/frontend/src/gql/gql.ts +++ b/frontend/src/gql/gql.ts @@ -47,7 +47,7 @@ const documents = { types.UserGreetingDocument, "\n fragment UserHome_user on User {\n id\n\n primaryEmail {\n id\n ...UserEmail_email\n }\n\n confirmedEmails: emails(first: 0, state: CONFIRMED) {\n totalCount\n }\n\n unverifiedEmails: emails(first: 0, state: PENDING) {\n totalCount\n }\n\n browserSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n oauth2Sessions(first: 0, state: ACTIVE) {\n totalCount\n }\n\n compatSessions(first: 0, state: ACTIVE) {\n totalCount\n }\n }\n": types.UserHome_UserFragmentDoc, - "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n": + "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n": types.AddEmailDocument, "\n query UserEmailListQuery(\n $userId: ID!\n $first: Int\n $after: String\n $last: Int\n $before: String\n ) {\n user(id: $userId) {\n id\n\n emails(first: $first, after: $after, last: $last, before: $before) {\n edges {\n cursor\n node {\n id\n ...UserEmail_email\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n hasPreviousPage\n startCursor\n endCursor\n }\n }\n }\n }\n": types.UserEmailListQueryDocument, @@ -191,8 +191,8 @@ export function graphql( * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. */ export function graphql( - source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n", -): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"]; + source: "\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n", +): (typeof documents)["\n mutation AddEmail($userId: ID!, $email: String!) {\n addEmail(input: { userId: $userId, email: $email }) {\n status\n violations\n email {\n id\n ...UserEmail_email\n }\n }\n }\n"]; /** * The graphql function is used to parse GraphQL queries into a document that can be used by GraphQL clients. */ diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index 2b39fb926..96748bedf 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -54,12 +54,16 @@ export type AddEmailPayload = { status: AddEmailStatus; /** The user to whom the email address was added */ user?: Maybe; + /** The list of policy violations if the email address was denied */ + violations?: Maybe>; }; /** The status of the `addEmail` mutation */ export enum AddEmailStatus { /** The email address was added */ Added = "ADDED", + /** The email address is not allowed by the policy */ + Denied = "DENIED", /** The email address already exists */ Exists = "EXISTS", /** The email address is invalid */ @@ -1231,6 +1235,7 @@ export type AddEmailMutation = { addEmail: { __typename?: "AddEmailPayload"; status: AddEmailStatus; + violations?: Array | null; email?: | ({ __typename?: "UserEmail"; id: string } & { " $fragmentRefs"?: { @@ -3129,6 +3134,7 @@ export const AddEmailDocument = { kind: "SelectionSet", selections: [ { kind: "Field", name: { kind: "Name", value: "status" } }, + { kind: "Field", name: { kind: "Name", value: "violations" } }, { kind: "Field", name: { kind: "Name", value: "email" }, diff --git a/frontend/src/gql/schema.ts b/frontend/src/gql/schema.ts index e5d546ab7..c27b9bd37 100644 --- a/frontend/src/gql/schema.ts +++ b/frontend/src/gql/schema.ts @@ -42,6 +42,20 @@ export default { }, args: [], }, + { + name: "violations", + type: { + kind: "LIST", + ofType: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + }, + args: [], + }, ], interfaces: [], }, diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index a1a49a8d4..afd230c4f 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "AuthorizationGrantInput", + "description": "Input for the authorization grant policy.", "type": "object", "required": [ "authorization_grant", diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json index 7261068e5..cc9957a85 100644 --- a/policies/schema/client_registration_input.json +++ b/policies/schema/client_registration_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "ClientRegistrationInput", + "description": "Input for the client registration policy.", "type": "object", "required": [ "client_metadata" diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json index 487eb4b92..19f4af523 100644 --- a/policies/schema/email_input.json +++ b/policies/schema/email_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "EmailInput", + "description": "Input for the email add policy.", "type": "object", "required": [ "email" diff --git a/policies/schema/password_input.json b/policies/schema/password_input.json index d85b2862e..c3cbf92d8 100644 --- a/policies/schema/password_input.json +++ b/policies/schema/password_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "PasswordInput", + "description": "Input for the password set policy.", "type": "object", "required": [ "password" diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index d77ce66e0..1f1585aa7 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -1,6 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "RegisterInput", + "description": "Input for the user registration policy.", "oneOf": [ { "type": "object", From 319de87a456b9cd9a5ea91f3c3beac9e32e87924 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 30 Aug 2023 18:36:53 +0200 Subject: [PATCH 3/3] Run the registration policy on upstream OAuth registration --- crates/axum-utils/src/fancy_error.rs | 1 + crates/handlers/src/upstream_oauth2/link.rs | 33 ++++++++++++++- .../handlers/src/views/account/emails/add.rs | 7 +--- crates/policy/src/lib.rs | 25 ++++++++++++ crates/policy/src/model.rs | 11 ++++- policies/Makefile | 4 +- policies/email_test.rego | 26 ++++++++++++ policies/password_test.rego | 29 ++++++++++++++ policies/register.rego | 22 ++++++++++ policies/register_test.rego | 24 ++++++----- policies/schema/register_input.json | 21 ++++++++++ templates/components/field.html | 3 +- templates/pages/account/emails/add.html | 2 +- templates/pages/error.html | 40 ++++++++++--------- 14 files changed, 207 insertions(+), 41 deletions(-) create mode 100644 policies/email_test.rego create mode 100644 policies/password_test.rego diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs index 86e0a60f5..88cefdbb7 100644 --- a/crates/axum-utils/src/fancy_error.rs +++ b/crates/axum-utils/src/fancy_error.rs @@ -24,6 +24,7 @@ pub struct FancyError { } impl FancyError { + #[must_use] pub fn new(context: ErrorContext) -> Self { Self { context } } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 9a49c4efc..60518788b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -21,10 +21,11 @@ use hyper::StatusCode; use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, - SessionInfoExt, + FancyError, SessionInfoExt, }; use mas_data_model::{UpstreamOAuthProviderImportPreference, User}; use mas_jose::jwt::Jwt; +use mas_policy::Policy; use mas_storage::{ job::{JobRepositoryExt, ProvisionUserJob}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, @@ -32,7 +33,8 @@ use mas_storage::{ BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ - TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, + ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, + UpstreamSuggestLink, }; use serde::Deserialize; use thiserror::Error; @@ -76,6 +78,11 @@ pub(crate) enum RouteError { #[error("Missing username")] MissingUsername, + #[error("Policy violation: {violations:?}")] + PolicyViolation { + violations: Vec, + }, + #[error(transparent)] Internal(Box), } @@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError); impl IntoResponse for RouteError { @@ -91,6 +99,16 @@ impl IntoResponse for RouteError { sentry::capture_error(&self); match self { Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), + Self::PolicyViolation { violations } => { + let details = violations.iter().map(|v| v.msg.clone()).collect::>(); + let details = details.join("\n"); + let ctx = ErrorContext::new() + .with_description( + "Account registration denied because of policy violation".to_owned(), + ) + .with_details(details); + FancyError::new(ctx).into_response() + } Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), } @@ -358,6 +376,7 @@ pub(crate) async fn post( mut repo: BoxRepository, cookie_jar: CookieJar, user_agent: Option>, + mut policy: Policy, Path(link_id): Path, Form(form): Form>, ) -> Result { @@ -478,6 +497,16 @@ pub(crate) async fn post( let username = username.ok_or(RouteError::MissingUsername)?; + // Policy check + let res = policy + .evaluate_upstream_oauth_register(&username, email.as_deref()) + .await?; + if !res.valid() { + return Err(RouteError::PolicyViolation { + violations: res.violations, + }); + } + // Now we can create the user let user = repo.user().add(&mut rng, &clock, username).await?; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index e4577d748..c8bc5da3f 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -107,12 +107,9 @@ pub(crate) async fn post( let user_email = if let Some(user_email) = existing_user_email { user_email } else { - let user_email = repo - .user_email() + repo.user_email() .add(&mut rng, &clock, &session.user, form.email) - .await?; - - user_email + .await? }; // If the email was not confirmed, send a confirmation email & redirect to the diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index a9d44c48b..95f963a76 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -251,6 +251,31 @@ impl Policy { Ok(res) } + #[tracing::instrument( + name = "policy.evaluate.upstream_oauth_register", + skip_all, + fields( + input.registration_method = "password", + input.user.username = username, + input.user.email = email, + ), + err, + )] + pub async fn evaluate_upstream_oauth_register( + &mut self, + username: &str, + email: Option<&str>, + ) -> Result { + let input = RegisterInput::UpstreamOAuth2 { username, email }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.register, &input) + .await?; + + Ok(res) + } + #[tracing::instrument(skip(self))] pub async fn evaluate_client_registration( &mut self, diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 65c4a0058..c17ce0417 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -56,14 +56,23 @@ impl EvaluationResult { /// Input for the user registration policy. #[derive(Serialize, Debug)] -#[serde(tag = "registration_method", rename_all = "snake_case")] +#[serde(tag = "registration_method")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub enum RegisterInput<'a> { + #[serde(rename = "password")] Password { username: &'a str, password: &'a str, email: &'a str, }, + + #[serde(rename = "upstream-oauth2")] + UpstreamOAuth2 { + username: &'a str, + + #[serde(skip_serializing_if = "Option::is_none")] + email: Option<&'a str>, + }, } /// Input for the client registration policy. diff --git a/policies/Makefile b/policies/Makefile index 2c9ff7c93..f67f53b50 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -46,5 +46,5 @@ coverage: .PHONY: lint lint: - $(OPA) fmt -d --fail . - $(OPA) check --strict . + $(OPA) fmt -d --fail ./*.rego util/*.rego + $(OPA) check --strict --schema schema/ ./*.rego util/*.rego diff --git a/policies/email_test.rego b/policies/email_test.rego new file mode 100644 index 000000000..efda51bd5 --- /dev/null +++ b/policies/email_test.rego @@ -0,0 +1,26 @@ +package email + +test_allow_all_domains { + allow with input.email as "hello@staging.element.io" +} + +test_allowed_domain { + allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["*.element.io"] +} + +test_not_allowed_domain { + not allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["example.com"] +} + +test_banned_domain { + not allow with input.email as "hello@staging.element.io" + with data.banned_domains as ["*.element.io"] +} + +test_banned_subdomain { + not allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["*.element.io"] + with data.banned_domains as ["staging.element.io"] +} diff --git a/policies/password_test.rego b/policies/password_test.rego new file mode 100644 index 000000000..4748974dd --- /dev/null +++ b/policies/password_test.rego @@ -0,0 +1,29 @@ +package password + +test_password_require_number { + allow with data.passwords.require_number as true + + not allow with input.password as "hunter" + with data.passwords.require_number as true +} + +test_password_require_lowercase { + allow with data.passwords.require_lowercase as true + + not allow with input.password as "HUNTER2" + with data.passwords.require_lowercase as true +} + +test_password_require_uppercase { + allow with data.passwords.require_uppercase as true + + not allow with input.password as "hunter2" + with data.passwords.require_uppercase as true +} + +test_password_min_length { + allow with data.passwords.min_length as 6 + + not allow with input.password as "short" + with data.passwords.min_length as 6 +} diff --git a/policies/register.rego b/policies/register.rego index b15e0fdce..34ada0216 100644 --- a/policies/register.rego +++ b/policies/register.rego @@ -22,6 +22,18 @@ violation[{"field": "username", "msg": "username too long"}] { count(input.username) >= 15 } +violation[{"field": "username", "msg": "username contains invalid characters"}] { + not regex.match("^[a-z0-9.=_/-]+$", input.username) +} + +violation[{"msg": "unspecified registration method"}] { + not input.registration_method +} + +violation[{"msg": "unknown registration method"}] { + not input.registration_method in ["password", "upstream-oauth2"] +} + violation[object.union({"field": "password"}, v)] { # Check if the registration method is password input.registration_method == "password" @@ -30,9 +42,19 @@ violation[object.union({"field": "password"}, v)] { some v in password_policy.violation } +# Check that we supplied an email for password registration +violation[{"field": "email", "msg": "email required for password-based registration"}] { + input.registration_method == "password" + + not input.email +} + # Check if the email is valid using the email policy # and add the email field to the violation object violation[object.union({"field": "email"}, v)] { + # Check if we have an email set in the input + input.email + # Get the violation object from the email policy some v in email_policy.violation } diff --git a/policies/register_test.rego b/policies/register_test.rego index 70acea87a..7e4ca3a01 100644 --- a/policies/register_test.rego +++ b/policies/register_test.rego @@ -32,54 +32,58 @@ test_banned_subdomain { with data.banned_domains as ["staging.element.io"] } +test_email_required { + not allow with input as {"username": "hello", "registration_method": "password"} +} + +test_no_email { + allow with input as {"username": "hello", "registration_method": "upstream-oauth2"} +} + test_short_username { - not allow with input as {"username": "a", "email": "hello@element.io"} + not allow with input as {"username": "a", "registration_method": "upstream-oauth2"} } test_long_username { - not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} + not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "registration_method": "upstream-oauth2"} +} + +test_invalid_username { + not allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"} } test_password_require_number { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_number as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "hunter" with data.passwords.require_number as true } test_password_require_lowercase { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_lowercase as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "HUNTER2" with data.passwords.require_lowercase as true } test_password_require_uppercase { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_uppercase as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "hunter2" with data.passwords.require_uppercase as true } test_password_min_length { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.min_length as 6 not allow with input as mock_registration - with input.registration_method as "password" with input.password as "short" with data.passwords.min_length as 6 } diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index 1f1585aa7..db0c137ab 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -28,6 +28,27 @@ "type": "string" } } + }, + { + "type": "object", + "required": [ + "registration_method", + "username" + ], + "properties": { + "email": { + "type": "string" + }, + "registration_method": { + "type": "string", + "enum": [ + "upstream-oauth2" + ] + }, + "username": { + "type": "string" + } + } } ] } \ No newline at end of file diff --git a/templates/components/field.html b/templates/components/field.html index fe68987a4..1aec8a832 100644 --- a/templates/components/field.html +++ b/templates/components/field.html @@ -14,7 +14,7 @@ limitations under the License. #} -{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false) %} +{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false) %} {% if not form_state %} {% set form_state = dict(errors=[], fields=dict()) %} {% endif %} @@ -35,6 +35,7 @@ class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" inputmode="{{ inputmode }}" + {% if required %} required {% endif %} {% if disabled %} disabled {% endif %} {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %} diff --git a/templates/pages/account/emails/add.html b/templates/pages/account/emails/add.html index 63ad8fedd..1fe3f155b 100644 --- a/templates/pages/account/emails/add.html +++ b/templates/pages/account/emails/add.html @@ -33,7 +33,7 @@

Add an email address

{% endif %} - {{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email") }} + {{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email", required=true) }} {{ button::button(text="Next") }} {% endblock content %} diff --git a/templates/pages/error.html b/templates/pages/error.html index 09b00b79d..219582b99 100644 --- a/templates/pages/error.html +++ b/templates/pages/error.html @@ -17,23 +17,25 @@ {% extends "base.html" %} {% block content %} -
-
-
- {% if code %} -

- {{ code }} -

- {% endif %} - {% if description %} -

- {{ description }} -

- {% endif %} - {% if details %} -
{{ details }}
- {% endif %} -
-
-
+
+
+

Unexpected error

+ {% if code %} +

+ {{ code }} +

+ {% endif %} + {% if description %} +

+ {{ description }} +

+ {% endif %} + {% if details %} +
+ +
{{ details }}
+
+ {% endif %} +
+
{% endblock %}