From 319de87a456b9cd9a5ea91f3c3beac9e32e87924 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 30 Aug 2023 18:36:53 +0200 Subject: [PATCH] Run the registration policy on upstream OAuth registration --- crates/axum-utils/src/fancy_error.rs | 1 + crates/handlers/src/upstream_oauth2/link.rs | 33 ++++++++++++++- .../handlers/src/views/account/emails/add.rs | 7 +--- crates/policy/src/lib.rs | 25 ++++++++++++ crates/policy/src/model.rs | 11 ++++- policies/Makefile | 4 +- policies/email_test.rego | 26 ++++++++++++ policies/password_test.rego | 29 ++++++++++++++ policies/register.rego | 22 ++++++++++ policies/register_test.rego | 24 ++++++----- policies/schema/register_input.json | 21 ++++++++++ templates/components/field.html | 3 +- templates/pages/account/emails/add.html | 2 +- templates/pages/error.html | 40 ++++++++++--------- 14 files changed, 207 insertions(+), 41 deletions(-) create mode 100644 policies/email_test.rego create mode 100644 policies/password_test.rego diff --git a/crates/axum-utils/src/fancy_error.rs b/crates/axum-utils/src/fancy_error.rs index 86e0a60f5..88cefdbb7 100644 --- a/crates/axum-utils/src/fancy_error.rs +++ b/crates/axum-utils/src/fancy_error.rs @@ -24,6 +24,7 @@ pub struct FancyError { } impl FancyError { + #[must_use] pub fn new(context: ErrorContext) -> Self { Self { context } } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 9a49c4efc..60518788b 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -21,10 +21,11 @@ use hyper::StatusCode; use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, - SessionInfoExt, + FancyError, SessionInfoExt, }; use mas_data_model::{UpstreamOAuthProviderImportPreference, User}; use mas_jose::jwt::Jwt; +use mas_policy::Policy; use mas_storage::{ job::{JobRepositoryExt, ProvisionUserJob}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, @@ -32,7 +33,8 @@ use mas_storage::{ BoxClock, BoxRepository, BoxRng, RepositoryAccess, }; use mas_templates::{ - TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, + ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, + UpstreamSuggestLink, }; use serde::Deserialize; use thiserror::Error; @@ -76,6 +78,11 @@ pub(crate) enum RouteError { #[error("Missing username")] MissingUsername, + #[error("Policy violation: {violations:?}")] + PolicyViolation { + violations: Vec, + }, + #[error(transparent)] Internal(Box), } @@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_policy::EvaluationError); impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError); impl IntoResponse for RouteError { @@ -91,6 +99,16 @@ impl IntoResponse for RouteError { sentry::capture_error(&self); match self { Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), + Self::PolicyViolation { violations } => { + let details = violations.iter().map(|v| v.msg.clone()).collect::>(); + let details = details.join("\n"); + let ctx = ErrorContext::new() + .with_description( + "Account registration denied because of policy violation".to_owned(), + ) + .with_details(details); + FancyError::new(ctx).into_response() + } Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), } @@ -358,6 +376,7 @@ pub(crate) async fn post( mut repo: BoxRepository, cookie_jar: CookieJar, user_agent: Option>, + mut policy: Policy, Path(link_id): Path, Form(form): Form>, ) -> Result { @@ -478,6 +497,16 @@ pub(crate) async fn post( let username = username.ok_or(RouteError::MissingUsername)?; + // Policy check + let res = policy + .evaluate_upstream_oauth_register(&username, email.as_deref()) + .await?; + if !res.valid() { + return Err(RouteError::PolicyViolation { + violations: res.violations, + }); + } + // Now we can create the user let user = repo.user().add(&mut rng, &clock, username).await?; diff --git a/crates/handlers/src/views/account/emails/add.rs b/crates/handlers/src/views/account/emails/add.rs index e4577d748..c8bc5da3f 100644 --- a/crates/handlers/src/views/account/emails/add.rs +++ b/crates/handlers/src/views/account/emails/add.rs @@ -107,12 +107,9 @@ pub(crate) async fn post( let user_email = if let Some(user_email) = existing_user_email { user_email } else { - let user_email = repo - .user_email() + repo.user_email() .add(&mut rng, &clock, &session.user, form.email) - .await?; - - user_email + .await? }; // If the email was not confirmed, send a confirmation email & redirect to the diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index a9d44c48b..95f963a76 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -251,6 +251,31 @@ impl Policy { Ok(res) } + #[tracing::instrument( + name = "policy.evaluate.upstream_oauth_register", + skip_all, + fields( + input.registration_method = "password", + input.user.username = username, + input.user.email = email, + ), + err, + )] + pub async fn evaluate_upstream_oauth_register( + &mut self, + username: &str, + email: Option<&str>, + ) -> Result { + let input = RegisterInput::UpstreamOAuth2 { username, email }; + + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.register, &input) + .await?; + + Ok(res) + } + #[tracing::instrument(skip(self))] pub async fn evaluate_client_registration( &mut self, diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 65c4a0058..c17ce0417 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -56,14 +56,23 @@ impl EvaluationResult { /// Input for the user registration policy. #[derive(Serialize, Debug)] -#[serde(tag = "registration_method", rename_all = "snake_case")] +#[serde(tag = "registration_method")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub enum RegisterInput<'a> { + #[serde(rename = "password")] Password { username: &'a str, password: &'a str, email: &'a str, }, + + #[serde(rename = "upstream-oauth2")] + UpstreamOAuth2 { + username: &'a str, + + #[serde(skip_serializing_if = "Option::is_none")] + email: Option<&'a str>, + }, } /// Input for the client registration policy. diff --git a/policies/Makefile b/policies/Makefile index 2c9ff7c93..f67f53b50 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -46,5 +46,5 @@ coverage: .PHONY: lint lint: - $(OPA) fmt -d --fail . - $(OPA) check --strict . + $(OPA) fmt -d --fail ./*.rego util/*.rego + $(OPA) check --strict --schema schema/ ./*.rego util/*.rego diff --git a/policies/email_test.rego b/policies/email_test.rego new file mode 100644 index 000000000..efda51bd5 --- /dev/null +++ b/policies/email_test.rego @@ -0,0 +1,26 @@ +package email + +test_allow_all_domains { + allow with input.email as "hello@staging.element.io" +} + +test_allowed_domain { + allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["*.element.io"] +} + +test_not_allowed_domain { + not allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["example.com"] +} + +test_banned_domain { + not allow with input.email as "hello@staging.element.io" + with data.banned_domains as ["*.element.io"] +} + +test_banned_subdomain { + not allow with input.email as "hello@staging.element.io" + with data.allowed_domains as ["*.element.io"] + with data.banned_domains as ["staging.element.io"] +} diff --git a/policies/password_test.rego b/policies/password_test.rego new file mode 100644 index 000000000..4748974dd --- /dev/null +++ b/policies/password_test.rego @@ -0,0 +1,29 @@ +package password + +test_password_require_number { + allow with data.passwords.require_number as true + + not allow with input.password as "hunter" + with data.passwords.require_number as true +} + +test_password_require_lowercase { + allow with data.passwords.require_lowercase as true + + not allow with input.password as "HUNTER2" + with data.passwords.require_lowercase as true +} + +test_password_require_uppercase { + allow with data.passwords.require_uppercase as true + + not allow with input.password as "hunter2" + with data.passwords.require_uppercase as true +} + +test_password_min_length { + allow with data.passwords.min_length as 6 + + not allow with input.password as "short" + with data.passwords.min_length as 6 +} diff --git a/policies/register.rego b/policies/register.rego index b15e0fdce..34ada0216 100644 --- a/policies/register.rego +++ b/policies/register.rego @@ -22,6 +22,18 @@ violation[{"field": "username", "msg": "username too long"}] { count(input.username) >= 15 } +violation[{"field": "username", "msg": "username contains invalid characters"}] { + not regex.match("^[a-z0-9.=_/-]+$", input.username) +} + +violation[{"msg": "unspecified registration method"}] { + not input.registration_method +} + +violation[{"msg": "unknown registration method"}] { + not input.registration_method in ["password", "upstream-oauth2"] +} + violation[object.union({"field": "password"}, v)] { # Check if the registration method is password input.registration_method == "password" @@ -30,9 +42,19 @@ violation[object.union({"field": "password"}, v)] { some v in password_policy.violation } +# Check that we supplied an email for password registration +violation[{"field": "email", "msg": "email required for password-based registration"}] { + input.registration_method == "password" + + not input.email +} + # Check if the email is valid using the email policy # and add the email field to the violation object violation[object.union({"field": "email"}, v)] { + # Check if we have an email set in the input + input.email + # Get the violation object from the email policy some v in email_policy.violation } diff --git a/policies/register_test.rego b/policies/register_test.rego index 70acea87a..7e4ca3a01 100644 --- a/policies/register_test.rego +++ b/policies/register_test.rego @@ -32,54 +32,58 @@ test_banned_subdomain { with data.banned_domains as ["staging.element.io"] } +test_email_required { + not allow with input as {"username": "hello", "registration_method": "password"} +} + +test_no_email { + allow with input as {"username": "hello", "registration_method": "upstream-oauth2"} +} + test_short_username { - not allow with input as {"username": "a", "email": "hello@element.io"} + not allow with input as {"username": "a", "registration_method": "upstream-oauth2"} } test_long_username { - not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} + not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "registration_method": "upstream-oauth2"} +} + +test_invalid_username { + not allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"} } test_password_require_number { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_number as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "hunter" with data.passwords.require_number as true } test_password_require_lowercase { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_lowercase as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "HUNTER2" with data.passwords.require_lowercase as true } test_password_require_uppercase { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.require_uppercase as true not allow with input as mock_registration - with input.registration_method as "password" with input.password as "hunter2" with data.passwords.require_uppercase as true } test_password_min_length { allow with input as mock_registration - with input.registration_method as "password" with data.passwords.min_length as 6 not allow with input as mock_registration - with input.registration_method as "password" with input.password as "short" with data.passwords.min_length as 6 } diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index 1f1585aa7..db0c137ab 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -28,6 +28,27 @@ "type": "string" } } + }, + { + "type": "object", + "required": [ + "registration_method", + "username" + ], + "properties": { + "email": { + "type": "string" + }, + "registration_method": { + "type": "string", + "enum": [ + "upstream-oauth2" + ] + }, + "username": { + "type": "string" + } + } } ] } \ No newline at end of file diff --git a/templates/components/field.html b/templates/components/field.html index fe68987a4..1aec8a832 100644 --- a/templates/components/field.html +++ b/templates/components/field.html @@ -14,7 +14,7 @@ limitations under the License. #} -{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false) %} +{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false) %} {% if not form_state %} {% set form_state = dict(errors=[], fields=dict()) %} {% endif %} @@ -35,6 +35,7 @@ class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0" type="{{ type }}" inputmode="{{ inputmode }}" + {% if required %} required {% endif %} {% if disabled %} disabled {% endif %} {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %} diff --git a/templates/pages/account/emails/add.html b/templates/pages/account/emails/add.html index 63ad8fedd..1fe3f155b 100644 --- a/templates/pages/account/emails/add.html +++ b/templates/pages/account/emails/add.html @@ -33,7 +33,7 @@

Add an email address

{% endif %} - {{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email") }} + {{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email", required=true) }} {{ button::button(text="Next") }} {% endblock content %} diff --git a/templates/pages/error.html b/templates/pages/error.html index 09b00b79d..219582b99 100644 --- a/templates/pages/error.html +++ b/templates/pages/error.html @@ -17,23 +17,25 @@ {% extends "base.html" %} {% block content %} -
-
-
- {% if code %} -

- {{ code }} -

- {% endif %} - {% if description %} -

- {{ description }} -

- {% endif %} - {% if details %} -
{{ details }}
- {% endif %} -
-
-
+
+
+

Unexpected error

+ {% if code %} +

+ {{ code }} +

+ {% endif %} + {% if description %} +

+ {{ description }} +

+ {% endif %} + {% if details %} +
+ +
{{ details }}
+
+ {% endif %} +
+
{% endblock %}