Skip to content

Commit

Permalink
refactor: replace identity lookups (pypi#14949)
Browse files Browse the repository at this point in the history
  • Loading branch information
miketheman authored Dec 4, 2023
1 parent 3c11488 commit d1e4993
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 84 deletions.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def pyramid_request(pyramid_services, jinja, remote_addr, remote_addr_hashed):
dummy_request.remote_addr_hashed = remote_addr_hashed
dummy_request.authentication_method = pretend.stub()
dummy_request._unauthenticated_userid = None
dummy_request.user = None
dummy_request.oidc_publisher = None

dummy_request.registry.registerUtility(jinja, IJinja2Environment, name=".jinja2")
Expand Down
42 changes: 22 additions & 20 deletions tests/unit/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ def test_returns_user(self, db_request):

class TestAccountsSearch:
def test_unauthenticated_raises_401(self):
pyramid_request = pretend.stub(authenticated_userid=None)
pyramid_request = pretend.stub(user=None)
with pytest.raises(HTTPUnauthorized):
views.accounts_search(pyramid_request)

def test_no_query_string_raises_400(self):
pyramid_request = pretend.stub(authenticated_userid=1, params=MultiDict({}))
pyramid_request = pretend.stub(user=pretend.stub(), params=MultiDict({}))
with pytest.raises(HTTPBadRequest):
views.accounts_search(pyramid_request)

Expand All @@ -162,7 +162,7 @@ def test_returns_users_with_prefix(self, db_session, user_service):
]

request = pretend.stub(
authenticated_userid=1,
user=pretend.stub(),
find_service=lambda svc, **kw: {
IUserService: user_service,
IRateLimiter: pretend.stub(
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_when_rate_limited(self, db_session):
test=pretend.call_recorder(lambda ip_address: False),
)
request = pretend.stub(
authenticated_userid=1,
user=pretend.stub(),
find_service=lambda svc, **kw: {
IRateLimiter: search_limiter,
}[svc],
Expand Down Expand Up @@ -405,7 +405,7 @@ def test_post_validate_no_redirects(
pyramid_request.session.record_password_timestamp = lambda timestamp: None

security_policy = pretend.stub(
authenticated_userid=lambda r: None,
identity=lambda r: None,
remember=lambda r, u, **kw: [],
reset=pretend.call_recorder(lambda r: None),
)
Expand All @@ -418,6 +418,7 @@ def test_post_validate_no_redirects(
)
form_class = pretend.call_recorder(lambda d, **kw: form_obj)
pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect")

result = views.login(pyramid_request, _form_class=form_class)

assert isinstance(result, HTTPSeeOther)
Expand All @@ -433,7 +434,7 @@ def test_post_validate_no_redirects(
assert security_policy.reset.calls == [pretend.call(pyramid_request)]

def test_redirect_authenticated_user(self):
pyramid_request = pretend.stub(authenticated_userid=1)
pyramid_request = pretend.stub(user=pretend.stub())
pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect")
result = views.login(pyramid_request)
assert isinstance(result, HTTPSeeOther)
Expand Down Expand Up @@ -697,6 +698,7 @@ def test_get_returns_recovery_code_status(self, pyramid_request, redirect_url):
}[interface]
pyramid_request.registry.settings = {"remember_device.days": 30}
pyramid_request.query_string = pretend.stub()

result = views.two_factor_and_totp_validate(
pyramid_request, _form_class=pretend.stub()
)
Expand Down Expand Up @@ -765,9 +767,6 @@ def test_totp_auth(
get_password_timestamp=lambda userid: 0,
)

pyramid_request.set_property(
lambda r: str(uuid.uuid4()), name="unauthenticated_userid"
)
pyramid_request.session.record_auth_timestamp = pretend.call_recorder(
lambda *args: None
)
Expand Down Expand Up @@ -824,7 +823,7 @@ def test_totp_auth(

def test_totp_auth_already_authed(self):
request = pretend.stub(
authenticated_userid="not_none",
identity=pretend.stub(),
route_path=pretend.call_recorder(lambda p: "redirect_to"),
)
result = views.two_factor_and_totp_validate(request)
Expand Down Expand Up @@ -858,7 +857,7 @@ def test_totp_form_invalid(self):
POST={},
method="POST",
session=pretend.stub(flash=pretend.call_recorder(lambda *a, **kw: None)),
authenticated_userid=None,
identity=None,
route_path=pretend.call_recorder(lambda p: "redirect_to"),
find_service=lambda interface, **kwargs: {
ITokenService: token_service,
Expand Down Expand Up @@ -894,6 +893,7 @@ def test_two_factor_token_missing_userid(self, pyramid_request):
ITokenService: token_service
}[interface]
pyramid_request.query_string = pretend.stub()

result = views.two_factor_and_totp_validate(pyramid_request)

assert token_service.loads.calls == [
Expand Down Expand Up @@ -929,7 +929,7 @@ def test_two_factor_token_invalid(self, pyramid_request):

class TestWebAuthn:
def test_webauthn_get_options_already_authenticated(self):
request = pretend.stub(authenticated_userid=pretend.stub(), _=lambda a: a)
request = pretend.stub(user=pretend.stub(), _=lambda a: a)

result = views.webauthn_authentication_options(request)

Expand Down Expand Up @@ -967,7 +967,7 @@ def test_webauthn_get_options(self, monkeypatch):
),
registry=pretend.stub(settings=pretend.stub(get=lambda *a: pretend.stub())),
domain=pretend.stub(),
authenticated_userid=None,
user=None,
find_service=lambda interface, **kwargs: user_service,
)

Expand All @@ -977,7 +977,8 @@ def test_webauthn_get_options(self, monkeypatch):
assert result == {"not": "real"}

def test_webauthn_validate_already_authenticated(self):
request = pretend.stub(authenticated_userid=pretend.stub())
# TODO: Determine why we can't use `request.user` here.
request = pretend.stub(identity=pretend.stub())
result = views.webauthn_authentication_validate(request)

assert result == {"fail": {"errors": ["Already authenticated"]}}
Expand All @@ -1004,7 +1005,8 @@ def test_webauthn_validate_invalid_form(self, monkeypatch):
monkeypatch.setattr(views, "_get_two_factor_data", _get_two_factor_data)

request = pretend.stub(
authenticated_userid=None,
# TODO: Determine why we can't use `request.user` here.
identity=None,
POST={},
session=pretend.stub(
get_webauthn_challenge=pretend.call_recorder(lambda: "not_real"),
Expand Down Expand Up @@ -1188,7 +1190,7 @@ def test_remember_device(self):
class TestRecoveryCode:
def test_already_authenticated(self):
request = pretend.stub(
authenticated_userid="not_none",
user=pretend.stub(),
route_path=pretend.call_recorder(lambda p: "redirect_to"),
)
result = views.recovery_code(request)
Expand Down Expand Up @@ -1382,7 +1384,7 @@ def test_recovery_code_form_invalid(self):
POST={},
method="POST",
session=pretend.stub(flash=pretend.call_recorder(lambda *a, **kw: None)),
authenticated_userid=None,
user=None,
route_path=pretend.call_recorder(lambda p: "redirect_to"),
find_service=lambda interface, **kwargs: {
ITokenService: token_service,
Expand Down Expand Up @@ -1510,7 +1512,7 @@ def test_get(self, db_request):
assert result["form"] is form_inst

def test_redirect_authenticated_user(self):
pyramid_request = pretend.stub(authenticated_userid=1)
pyramid_request = pretend.stub(user=pretend.stub())
pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect")
result = views.register(pyramid_request)
assert isinstance(result, HTTPSeeOther)
Expand Down Expand Up @@ -1957,7 +1959,7 @@ def test_password_reset_prohibited(
]

def test_redirect_authenticated_user(self):
pyramid_request = pretend.stub(authenticated_userid=1)
pyramid_request = pretend.stub(user=pretend.stub())
pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect")
result = views.request_password_reset(pyramid_request)
assert isinstance(result, HTTPSeeOther)
Expand Down Expand Up @@ -2317,7 +2319,7 @@ def test_reset_password_password_date_changed(self, pyramid_request):
]

def test_redirect_authenticated_user(self):
pyramid_request = pretend.stub(authenticated_userid=1)
pyramid_request = pretend.stub(user=pretend.stub())
pyramid_request.route_path = pretend.call_recorder(lambda a: "/the-redirect")
result = views.reset_password(pyramid_request)
assert isinstance(result, HTTPSeeOther)
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def test_logged_in_returns_exception(self, pyramid_config):
exc = pretend.stub(
status_code=403, status="403 Forbidden", headers={}, result=pretend.stub()
)
request = pretend.stub(authenticated_userid=1, context=None)
request = pretend.stub(user=pretend.stub(), context=None)
resp = forbidden(exc, request)
assert resp.status_code == 403
renderer.assert_()

def test_logged_out_redirects_login(self):
exc = pretend.stub()
request = pretend.stub(
authenticated_userid=None,
user=None,
path_qs="/foo/bar/?b=s",
route_url=pretend.call_recorder(
lambda route, _query: "/accounts/login/?next=/foo/bar/%3Fb%3Ds"
Expand All @@ -236,7 +236,7 @@ def test_two_factor_required(self, reason):
result = WarehouseDenied("Some summary", reason=reason)
exc = pretend.stub(result=result)
request = pretend.stub(
authenticated_userid=1,
user=pretend.stub(),
session=pretend.stub(flash=pretend.call_recorder(lambda x, queue: None)),
path_qs="/foo/bar/?b=s",
route_url=pretend.call_recorder(
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_unverified_email_redirects(self, requested_path):
result = WarehouseDenied("Some summary", reason="unverified_email")
exc = pretend.stub(result=result)
request = pretend.stub(
authenticated_userid=1,
user=pretend.stub(),
session=pretend.stub(flash=pretend.call_recorder(lambda x, queue: None)),
path_qs=requested_path,
route_url=pretend.call_recorder(lambda route, _query: "/manage/account/"),
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_generic_warehousedeined(self, pyramid_config):
exc = pretend.stub(
status_code=403, status="403 Forbidden", headers={}, result=result
)
request = pretend.stub(authenticated_userid=1, context=None)
request = pretend.stub(user=pretend.stub(), context=None)
resp = forbidden(exc, request)
assert resp.status_code == 403
renderer.assert_()
Expand Down
22 changes: 13 additions & 9 deletions warehouse/accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def accounts_search(request) -> dict[str, list[User]]:
Used with autocomplete.
User must be logged in.
"""
if request.authenticated_userid is None:
if request.user is None:
raise HTTPUnauthorized()

form = UsernameSearchForm(request.params)
Expand Down Expand Up @@ -217,7 +217,7 @@ def login(request, redirect_field_name=REDIRECT_FIELD_NAME, _form_class=LoginFor
# TODO: Logging in should reset request.user
# TODO: Configure the login view as the default view for not having
# permission to view something.
if request.authenticated_userid is not None:
if request.user is not None:
return HTTPSeeOther(request.route_path("manage.projects"))

user_service = request.find_service(IUserService, context=None)
Expand Down Expand Up @@ -308,7 +308,9 @@ def login(request, redirect_field_name=REDIRECT_FIELD_NAME, _form_class=LoginFor
has_translations=True,
)
def two_factor_and_totp_validate(request, _form_class=TOTPAuthenticationForm):
if request.authenticated_userid is not None:
# TODO: Using `request.user` here fails `test_totp_auth()` because
# of how the test is constructed. We should fix that.
if request.identity is not None:
return HTTPSeeOther(request.route_path("manage.projects"))

try:
Expand Down Expand Up @@ -377,7 +379,7 @@ def two_factor_and_totp_validate(request, _form_class=TOTPAuthenticationForm):
has_translations=True,
)
def webauthn_authentication_options(request):
if request.authenticated_userid is not None:
if request.user is not None:
return {"fail": {"errors": [request._("Already authenticated")]}}

try:
Expand Down Expand Up @@ -406,7 +408,9 @@ def webauthn_authentication_options(request):
has_translations=True,
)
def webauthn_authentication_validate(request):
if request.authenticated_userid is not None:
# TODO: Using `request.user` here fails `test_webauthn_validate()` because
# of how the test is constructed. We should fix that.
if request.identity is not None:
return {"fail": {"errors": ["Already authenticated"]}}

try:
Expand Down Expand Up @@ -514,7 +518,7 @@ def _remember_device(request, response, userid, two_factor_method) -> None:
has_translations=True,
)
def recovery_code(request, _form_class=RecoveryCodeAuthenticationForm):
if request.authenticated_userid is not None:
if request.user is not None:
return HTTPSeeOther(request.route_path("manage.projects"))

try:
Expand Down Expand Up @@ -635,7 +639,7 @@ def logout(request, redirect_field_name=REDIRECT_FIELD_NAME):
has_translations=True,
)
def register(request, _form_class=RegistrationForm):
if request.authenticated_userid is not None:
if request.user is not None:
return HTTPSeeOther(request.route_path("manage.projects"))

# Check if the honeypot field has been filled
Expand Down Expand Up @@ -701,7 +705,7 @@ def register(request, _form_class=RegistrationForm):
has_translations=True,
)
def request_password_reset(request, _form_class=RequestPasswordResetForm):
if request.authenticated_userid is not None:
if request.user is not None:
return HTTPSeeOther(request.route_path("index"))

user_service = request.find_service(IUserService, context=None)
Expand Down Expand Up @@ -760,7 +764,7 @@ def request_password_reset(request, _form_class=RequestPasswordResetForm):
has_translations=True,
)
def reset_password(request, _form_class=ResetPasswordForm):
if request.authenticated_userid is not None:
if request.user is not None:
return HTTPSeeOther(request.route_path("index"))

user_service = request.find_service(IUserService, context=None)
Expand Down
Loading

0 comments on commit d1e4993

Please sign in to comment.