Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Run the registration policy on upstream OAuth registration
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Aug 30, 2023
1 parent cda799e commit 319de87
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 41 deletions.
1 change: 1 addition & 0 deletions crates/axum-utils/src/fancy_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct FancyError {
}

impl FancyError {
#[must_use]
pub fn new(context: ErrorContext) -> Self {
Self { context }
}
Expand Down
33 changes: 31 additions & 2 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
SessionInfoExt,
FancyError, SessionInfoExt,
};
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
use mas_jose::jwt::Jwt;
use mas_policy::Policy;
use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
};
use serde::Deserialize;
use thiserror::Error;
Expand Down Expand Up @@ -76,6 +78,11 @@ pub(crate) enum RouteError {
#[error("Missing username")]
MissingUsername,

#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
Expand All @@ -84,13 +91,24 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
sentry::capture_error(&self);
match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
Expand Down Expand Up @@ -358,6 +376,7 @@ pub(crate) async fn post(
mut repo: BoxRepository,
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
Expand Down Expand Up @@ -478,6 +497,16 @@ pub(crate) async fn post(

let username = username.ok_or(RouteError::MissingUsername)?;

// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
}

// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?;

Expand Down
7 changes: 2 additions & 5 deletions crates/handlers/src/views/account/emails/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,9 @@ pub(crate) async fn post(
let user_email = if let Some(user_email) = existing_user_email {
user_email
} else {
let user_email = repo
.user_email()
repo.user_email()
.add(&mut rng, &clock, &session.user, form.email)
.await?;

user_email
.await?
};

// If the email was not confirmed, send a confirmation email & redirect to the
Expand Down
25 changes: 25 additions & 0 deletions crates/policy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,31 @@ impl Policy {
Ok(res)
}

#[tracing::instrument(
name = "policy.evaluate.upstream_oauth_register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_upstream_oauth_register(
&mut self,
username: &str,
email: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::UpstreamOAuth2 { username, email };

let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;

Ok(res)
}

#[tracing::instrument(skip(self))]
pub async fn evaluate_client_registration(
&mut self,
Expand Down
11 changes: 10 additions & 1 deletion crates/policy/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,23 @@ impl EvaluationResult {

/// Input for the user registration policy.
#[derive(Serialize, Debug)]
#[serde(tag = "registration_method", rename_all = "snake_case")]
#[serde(tag = "registration_method")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegisterInput<'a> {
#[serde(rename = "password")]
Password {
username: &'a str,
password: &'a str,
email: &'a str,
},

#[serde(rename = "upstream-oauth2")]
UpstreamOAuth2 {
username: &'a str,

#[serde(skip_serializing_if = "Option::is_none")]
email: Option<&'a str>,
},
}

/// Input for the client registration policy.
Expand Down
4 changes: 2 additions & 2 deletions policies/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ coverage:

.PHONY: lint
lint:
$(OPA) fmt -d --fail .
$(OPA) check --strict .
$(OPA) fmt -d --fail ./*.rego util/*.rego
$(OPA) check --strict --schema schema/ ./*.rego util/*.rego
26 changes: 26 additions & 0 deletions policies/email_test.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package email

test_allow_all_domains {
allow with input.email as "[email protected]"
}

test_allowed_domain {
allow with input.email as "[email protected]"
with data.allowed_domains as ["*.element.io"]
}

test_not_allowed_domain {
not allow with input.email as "[email protected]"
with data.allowed_domains as ["example.com"]
}

test_banned_domain {
not allow with input.email as "[email protected]"
with data.banned_domains as ["*.element.io"]
}

test_banned_subdomain {
not allow with input.email as "[email protected]"
with data.allowed_domains as ["*.element.io"]
with data.banned_domains as ["staging.element.io"]
}
29 changes: 29 additions & 0 deletions policies/password_test.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package password

test_password_require_number {
allow with data.passwords.require_number as true

not allow with input.password as "hunter"
with data.passwords.require_number as true
}

test_password_require_lowercase {
allow with data.passwords.require_lowercase as true

not allow with input.password as "HUNTER2"
with data.passwords.require_lowercase as true
}

test_password_require_uppercase {
allow with data.passwords.require_uppercase as true

not allow with input.password as "hunter2"
with data.passwords.require_uppercase as true
}

test_password_min_length {
allow with data.passwords.min_length as 6

not allow with input.password as "short"
with data.passwords.min_length as 6
}
22 changes: 22 additions & 0 deletions policies/register.rego
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ violation[{"field": "username", "msg": "username too long"}] {
count(input.username) >= 15
}

violation[{"field": "username", "msg": "username contains invalid characters"}] {
not regex.match("^[a-z0-9.=_/-]+$", input.username)
}

violation[{"msg": "unspecified registration method"}] {
not input.registration_method
}

violation[{"msg": "unknown registration method"}] {
not input.registration_method in ["password", "upstream-oauth2"]
}

violation[object.union({"field": "password"}, v)] {
# Check if the registration method is password
input.registration_method == "password"
Expand All @@ -30,9 +42,19 @@ violation[object.union({"field": "password"}, v)] {
some v in password_policy.violation
}

# Check that we supplied an email for password registration
violation[{"field": "email", "msg": "email required for password-based registration"}] {
input.registration_method == "password"

not input.email
}

# Check if the email is valid using the email policy
# and add the email field to the violation object
violation[object.union({"field": "email"}, v)] {
# Check if we have an email set in the input
input.email

# Get the violation object from the email policy
some v in email_policy.violation
}
24 changes: 14 additions & 10 deletions policies/register_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -32,54 +32,58 @@ test_banned_subdomain {
with data.banned_domains as ["staging.element.io"]
}

test_email_required {
not allow with input as {"username": "hello", "registration_method": "password"}
}

test_no_email {
allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
}

test_short_username {
not allow with input as {"username": "a", "email": "[email protected]"}
not allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
}

test_long_username {
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "[email protected]"}
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "registration_method": "upstream-oauth2"}
}

test_invalid_username {
not allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"}
}

test_password_require_number {
allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_number as true

not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "hunter"
with data.passwords.require_number as true
}

test_password_require_lowercase {
allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_lowercase as true

not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "HUNTER2"
with data.passwords.require_lowercase as true
}

test_password_require_uppercase {
allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_uppercase as true

not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "hunter2"
with data.passwords.require_uppercase as true
}

test_password_min_length {
allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.min_length as 6

not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "short"
with data.passwords.min_length as 6
}
21 changes: 21 additions & 0 deletions policies/schema/register_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@
"type": "string"
}
}
},
{
"type": "object",
"required": [
"registration_method",
"username"
],
"properties": {
"email": {
"type": "string"
},
"registration_method": {
"type": "string",
"enum": [
"upstream-oauth2"
]
},
"username": {
"type": "string"
}
}
}
]
}
3 changes: 2 additions & 1 deletion templates/components/field.html
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
#}

{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false) %}
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false) %}
{% if not form_state %}
{% set form_state = dict(errors=[], fields=dict()) %}
{% endif %}
Expand All @@ -35,6 +35,7 @@
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0"
type="{{ type }}"
inputmode="{{ inputmode }}"
{% if required %} required {% endif %}
{% if disabled %} disabled {% endif %}
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
{% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %}
Expand Down
2 changes: 1 addition & 1 deletion templates/pages/account/emails/add.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ <h1 class="text-lg text-center font-medium">Add an email address</h1>
{% endif %}

<input type="hidden" name="csrf" value="{{ csrf_token }}" />
{{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email") }}
{{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email", required=true) }}
{{ button::button(text="Next") }}
</section>
{% endblock content %}
Loading

0 comments on commit 319de87

Please sign in to comment.