From b2cd749799d8db0caa7e14bfa19c15abff992160 Mon Sep 17 00:00:00 2001 From: Alexandr Jariuc Date: Mon, 27 Nov 2023 01:13:07 +1100 Subject: [PATCH 1/3] feat(terminate_all_session): Terminate all user sessions --- src/core/config.py | 6 ++++- src/db/redis.py | 47 +++++++++++++++++++++++++++++++++++++ src/v1/auth/helpers.py | 15 +++++++++++- src/v1/auth/routers.py | 26 +++++++++++++++++++-- src/v1/auth/service.py | 53 +++++++++++++++++++++++++++++++++++++----- 5 files changed, 137 insertions(+), 10 deletions(-) diff --git a/src/core/config.py b/src/core/config.py index 281784b..e99c4ac 100755 --- a/src/core/config.py +++ b/src/core/config.py @@ -14,7 +14,7 @@ class Settings(BaseSettings): listen_port: int = 8000 allowed_hosts: list = const.DEFAULT_ALLOWED_HOSTS - redis_host: str = "redis" + redis_host: str = "localhost" redis_port: int = 6379 redis_db: int = 0 @@ -42,6 +42,10 @@ def pg_dsn(self): f"{self.postgres_port}/{self.postgres_db}" ) + @cached_property + def redis_dsn(self): + return f"redis://{self.redis_host}:{self.redis_port}" + class Config: case_sensitive = False env_file = ".env" diff --git a/src/db/redis.py b/src/db/redis.py index e69de29..0042a1d 100644 --- a/src/db/redis.py +++ b/src/db/redis.py @@ -0,0 +1,47 @@ +from typing import Annotated + +from aioredis import Redis, from_url +from pydantic import UUID4 +from fastapi import Depends + +from src.db.storages import BaseStorage +from src.core.config import settings + + +class RedisBlacklistUserSignatureStorage(BaseStorage): + + def __init__(self) -> None: + self.protocol: Redis = from_url(settings.redis_dsn, decode_responses=True) + self.namespace: str = "auth_service" + + async def create(self, user_id: UUID4, signature: str): + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + await conn.set(signature_key, signature) + + async def get(self, user_id: UUID4) -> str: + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + return await conn.get(signature_key) + + async def delete(self, user_id: UUID4): + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + return await conn.delete(signature_key) + + async def delete_all(self, user_id: UUID4, count_size: int = 10) -> int: + pattern = f"{self.namespace}:{user_id}:*" + cursor = b"0" + deleted_count = 0 + + async with self.protocol.client() as conn: + while cursor: + cursor, keys = await conn.scan(cursor, match=pattern, count=count_size) + deleted_count += await conn.unlink(*keys) + return deleted_count + + +redis_blackist_storage = RedisBlacklistUserSignatureStorage() +BlacklistSignatureStorage = Annotated[ + RedisBlacklistUserSignatureStorage, Depends(redis_blackist_storage) +] diff --git a/src/v1/auth/helpers.py b/src/v1/auth/helpers.py index d329e75..5b9e5c4 100644 --- a/src/v1/auth/helpers.py +++ b/src/v1/auth/helpers.py @@ -5,9 +5,10 @@ from passlib.hash import pbkdf2_sha256 from pydantic import UUID4 +from src.db.redis import BlacklistSignatureStorage from src.core.config import settings from src.v1.auth.schemas import JWTTokens -from src.v1.auth.exceptions import InvalidTokenError +from src.v1.auth.exceptions import InvalidTokenError, UnauthorizedError # TODO(alexander.zharyuk): Improve generation. Maybe add some salt? @@ -65,3 +66,15 @@ def decode_jwt(token: str) -> dict: except JWTError: raise InvalidTokenError() return payload + + +async def validate_jwt(blacklist_tokens_storage: BlacklistSignatureStorage, token: str): + """Validate that token is not in blacklists""" + token_payload = decode_jwt(token) + token_headers = jwt.get_unverified_header(token) + + user_id = token_payload.get("user_id") + token_signature = token_headers.get("jti") + + if token_signature == await blacklist_tokens_storage.get(user_id): + raise UnauthorizedError() diff --git a/src/v1/auth/routers.py b/src/v1/auth/routers.py index bea5f20..387ce01 100755 --- a/src/v1/auth/routers.py +++ b/src/v1/auth/routers.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends, Request, Response, status from fastapi.security import APIKeyCookie -from src.db.postgres import DatabaseSession -from src.db.postgres import RefreshTokensStorage +from src.db.postgres import DatabaseSession, RefreshTokensStorage +from src.db.redis import BlacklistSignatureStorage from src.v1.auth.schemas import (TokensResponse, UserCreate, UserLogin, UserResponse, LogoutResponse, UserLogout) from src.v1.auth.service import AuthService @@ -54,3 +54,25 @@ async def logout( ) return LogoutResponse() + +@router.post( + "/logout_all", + summary="Выход из всех сессий пользователя", + response_model=LogoutResponse +) +async def terminate_all_sessions( + db_session: DatabaseSession, + blacklist_signatures_storage: BlacklistSignatureStorage, + refresh_token_storage: RefreshTokensStorage, + response: Response, + access_token: str | None = Depends(cookie_scheme) +) -> TokensResponse: + """Выход из всех сессий пользователя""" + await AuthService.terminate_all_sessions( + db_session, + blacklist_signatures_storage, + refresh_token_storage, + response, + access_token + ) + return LogoutResponse() \ No newline at end of file diff --git a/src/v1/auth/service.py b/src/v1/auth/service.py index bc018ea..c9ffcfa 100755 --- a/src/v1/auth/service.py +++ b/src/v1/auth/service.py @@ -3,10 +3,11 @@ from datetime import datetime from fastapi import Request, Response -from pydantic import BaseModel, UUID4 -from sqlalchemy import and_, or_, select, delete +from pydantic import BaseModel +from sqlalchemy import and_, or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from jose import jwt from src.core.config import settings from src.db.postgres import RefreshTokensStorage @@ -16,9 +17,11 @@ generate_user_signature, hash_password, verify_password, - decode_jwt + decode_jwt, + validate_jwt ) from src.v1.auth.schemas import JWTTokens, User, UserCreate, UserLogin +from src.db.redis import BlacklistSignatureStorage from src.v1.exceptions import ServiceError from src.v1.users.service import UserService from src.v1.users import models as users_models @@ -91,7 +94,7 @@ async def signin( # TODO: Add role to JWT tokens = generate_jwt( - payload={"user_id": str(exists_user.id)}, + payload={"user_id": str(exists_user.id), "username": exists_user.username}, access_jti=exists_user.signature.signature, refresh_jti=uuid.uuid4(), ) @@ -116,8 +119,30 @@ async def logout( async def verify(current_user=None): ... - async def terminate_all_sessions(current_user=None): - ... + @staticmethod + async def terminate_all_sessions( + db_session: AsyncSession, + blacklist_signatures_storage: BlacklistSignatureStorage, + refresh_token_storage: RefreshTokensStorage, + response: Response, + access_token: str + ): + """Terminate all sessions""" + + await validate_jwt(blacklist_signatures_storage, access_token) + response.delete_cookie(settings.sessions_cookie_name) + token_payload = decode_jwt(access_token) + token_headers = jwt.get_unverified_header(access_token) + + username = token_payload.get("username") + user_id = token_payload.get("user_id") + old_user_signature = token_headers.get("jti") + new_user_signature = generate_user_signature(username) + + await blacklist_signatures_storage.create(user_id, old_user_signature) + await __class__._update_user_signature(db_session, old_user_signature, new_user_signature) + + await refresh_token_storage.delete_all(db_session, user_id=user_id) @staticmethod async def _save_login_session_if_not_exists( @@ -166,5 +191,21 @@ async def _set_user_cookie(cookie_key: str, cookie_value: str, response: Respons expires=settings.jwt_access_expire_time_in_seconds ) + @staticmethod + async def _update_user_signature( + db_session: AsyncSession, + old_user_signature: str, + new_user_signature: str + ): + """Update user signature when user terminate all sessions.""" + statement = update(users_models.UserSignature).where( + users_models.UserSignature.signature == old_user_signature + ).values({users_models.UserSignature.signature: new_user_signature}) + await db_session.execute(statement) + try: + await db_session.commit() + except SQLAlchemyError: + await db_session.rollback() + raise ServiceError() AuthService = JWTAuthService() From 79270334784d0e9e437a551286d423f4e56a0621 Mon Sep 17 00:00:00 2001 From: Alexandr Jariuc Date: Tue, 28 Nov 2023 02:31:03 +1100 Subject: [PATCH 2/3] feat(auth_routers): Implement session logout --- requirements.txt | 2 +- src/db/postgres.py | 23 ++++++++-- src/db/redis.py | 8 +++- src/v1/auth/exceptions.py | 13 ++++++ src/v1/auth/helpers.py | 13 +++--- src/v1/auth/schemas.py | 12 +++++ src/v1/auth/service.py | 93 +++++++++++++++++++++++++++++++-------- src/v1/users/service.py | 2 +- 8 files changed, 133 insertions(+), 33 deletions(-) diff --git a/requirements.txt b/requirements.txt index a0f4feb..cca0d85 100755 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ passlib==1.7.4 pydantic==2.4.2 email-validator==2.1.0 pydantic-settings==2.0.3 -aioredis==2.0.1 +redis==5.0.1 fastapi==0.103.2 uvicorn==0.23.2 gunicorn==21.2.0 diff --git a/src/db/postgres.py b/src/db/postgres.py index c60e4b7..433f865 100644 --- a/src/db/postgres.py +++ b/src/db/postgres.py @@ -5,12 +5,13 @@ from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.exc import SQLAlchemyError -from jose import jwt +from jose import jwt, JWTError from pydantic import UUID4 from sqlalchemy import delete from src.core.config import settings from src.v1.exceptions import ServiceError +from src.v1.auth.exceptions import InvalidTokenError from src.db.storages import Database, BaseStorage from src.v1.auth.helpers import decode_jwt from src.v1.users.models import UserRefreshTokens @@ -40,11 +41,15 @@ class PostgresRefreshTokenStorage(BaseStorage): @staticmethod async def create(db_session: AsyncSession, refresh_token: str, user_id: UUID4) -> UUID4: token_headers = jwt.get_unverified_header(refresh_token) - token_data = decode_jwt(refresh_token) + try: + token_payload = decode_jwt(refresh_token) + await __class__._verify_that_token_is_refresh(token_payload) + except JWTError: + raise InvalidTokenError() refresh_token = UserRefreshTokens( token=token_headers.get("jti"), user_id=user_id, - expire_at=datetime.fromtimestamp(token_data.get("exp")), + expire_at=datetime.fromtimestamp(token_payload.get("exp")), ) db_session.add(refresh_token) try: @@ -62,7 +67,12 @@ async def get(db_session: AsyncSession, token: str) -> UserRefreshTokens: @staticmethod async def delete(db_session: AsyncSession, token: str): - decode_jwt(token) + try: + token_payload = decode_jwt(token) + await __class__._verify_that_token_is_refresh(token_payload) + except JWTError: + raise InvalidTokenError() + token_headers = jwt.get_unverified_header(token) token_id = token_headers.get("jti") @@ -83,6 +93,11 @@ async def delete_all(db_session: AsyncSession, user_id: UUID4): except SQLAlchemyError: await db_session.rollback() + @staticmethod + async def _verify_that_token_is_refresh(token_payload: dict): + if len(token_payload.values()) > 1: + raise JWTError() + db_session = PostgresDatabase() refresh_tokens_storage = PostgresRefreshTokenStorage() diff --git a/src/db/redis.py b/src/db/redis.py index 0042a1d..64c872f 100644 --- a/src/db/redis.py +++ b/src/db/redis.py @@ -1,6 +1,6 @@ from typing import Annotated -from aioredis import Redis, from_url +from redis.asyncio import Redis, from_url from pydantic import UUID4 from fastapi import Depends @@ -11,7 +11,11 @@ class RedisBlacklistUserSignatureStorage(BaseStorage): def __init__(self) -> None: - self.protocol: Redis = from_url(settings.redis_dsn, decode_responses=True) + self.protocol: Redis = from_url( + settings.redis_dsn, + decode_responses=True, + db=settings.redis_db + ) self.namespace: str = "auth_service" async def create(self, user_id: UUID4, signature: str): diff --git a/src/v1/auth/exceptions.py b/src/v1/auth/exceptions.py index d71c88c..234829e 100644 --- a/src/v1/auth/exceptions.py +++ b/src/v1/auth/exceptions.py @@ -10,6 +10,7 @@ class AuthExceptionCodes: USER_UNAUTHORIZED: int = 3002 PROVIDED_PASSWORD_INCORRECT: int = 3003 INVALID_PROVIDED_TOKEN: int = 3004 + TOKEN_NOT_FOUND: int = 3005 class UserAlreadyExistsError(HTTPException): @@ -58,3 +59,15 @@ def __init__( ) -> None: detail = {"code": AuthExceptionCodes.INVALID_PROVIDED_TOKEN, "message": message} super().__init__(status_code=status_code, detail=detail) + + +class TokenNotFoundError(HTTPException): + """Error raised whe token doesnt exists in DB.""" + + def __init__( + self, + status_code: int = status.HTTP_404_NOT_FOUND, + message: str = "Invalid token.", + ) -> None: + detail = {"code": AuthExceptionCodes.TOKEN_NOT_FOUND, "message": message} + super().__init__(status_code=status_code, detail=detail) diff --git a/src/v1/auth/helpers.py b/src/v1/auth/helpers.py index 5b9e5c4..04a0f87 100644 --- a/src/v1/auth/helpers.py +++ b/src/v1/auth/helpers.py @@ -8,7 +8,7 @@ from src.db.redis import BlacklistSignatureStorage from src.core.config import settings from src.v1.auth.schemas import JWTTokens -from src.v1.auth.exceptions import InvalidTokenError, UnauthorizedError +from src.v1.auth.exceptions import UnauthorizedError, InvalidTokenError # TODO(alexander.zharyuk): Improve generation. Maybe add some salt? @@ -57,20 +57,19 @@ def generate_jwt(payload: dict, access_jti: str, refresh_jti: UUID4) -> JWTToken def decode_jwt(token: str) -> dict: """Decode access / refresh tokens payload""" - try: - payload = jwt.decode( + return jwt.decode( token, key=settings.jwt_secret_key, algorithms=[settings.jwt_algorithm] ) - except JWTError: - raise InvalidTokenError() - return payload async def validate_jwt(blacklist_tokens_storage: BlacklistSignatureStorage, token: str): """Validate that token is not in blacklists""" - token_payload = decode_jwt(token) + try: + token_payload = decode_jwt(token) + except JWTError: + raise InvalidTokenError() token_headers = jwt.get_unverified_header(token) user_id = token_payload.get("user_id") diff --git a/src/v1/auth/schemas.py b/src/v1/auth/schemas.py index ac02ff6..db59935 100644 --- a/src/v1/auth/schemas.py +++ b/src/v1/auth/schemas.py @@ -29,6 +29,10 @@ class UserLogout(BaseModel): refresh_token: str +class RefreshTokens(UserLogout): + ... + + class User(UserBase): id: UUID4 @@ -51,3 +55,11 @@ class TokensResponse(BaseResponseBody): class LogoutResponse(BaseResponseBody): data: dict = {"sucess": True} + + +class VerifyTokenResponse(BaseResponseBody): + data: dict = {"access": True} + + +class JWTPayload(BaseModel): + user_id: UUID4 diff --git a/src/v1/auth/service.py b/src/v1/auth/service.py index c9ffcfa..117712c 100755 --- a/src/v1/auth/service.py +++ b/src/v1/auth/service.py @@ -3,15 +3,17 @@ from datetime import datetime from fastapi import Request, Response -from pydantic import BaseModel +from pydantic import BaseModel, UUID4 from sqlalchemy import and_, or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from jose import jwt +from jose import jwt, JWTError from src.core.config import settings from src.db.postgres import RefreshTokensStorage -from src.v1.auth.exceptions import UserAlreadyExistsError +from src.v1.auth.exceptions import ( + UserAlreadyExistsError, UnauthorizedError, TokenNotFoundError, InvalidTokenError +) from src.v1.auth.helpers import ( generate_jwt, generate_user_signature, @@ -20,7 +22,8 @@ decode_jwt, validate_jwt ) -from src.v1.auth.schemas import JWTTokens, User, UserCreate, UserLogin +from src.v1.auth.schemas import (JWTTokens, User, UserCreate, UserLogin, + VerifyTokenResponse, JWTPayload) from src.db.redis import BlacklistSignatureStorage from src.v1.exceptions import ServiceError from src.v1.users.service import UserService @@ -93,8 +96,9 @@ async def signin( await __class__._save_login_session_if_not_exists(db_session, exists_user, request) # TODO: Add role to JWT + jwt_payload = JWTPayload(user_id=exists_user.id).model_dump(mode="json") tokens = generate_jwt( - payload={"user_id": str(exists_user.id), "username": exists_user.username}, + payload=jwt_payload, access_jti=exists_user.signature.signature, refresh_jti=uuid.uuid4(), ) @@ -115,9 +119,15 @@ async def logout( response.delete_cookie(settings.sessions_cookie_name) await refresh_token_storage.delete(db_session, refresh_token) - - async def verify(current_user=None): - ... + @staticmethod + async def verify(access_token: str, blacklist_signatures_storage: BlacklistSignatureStorage): + """Verfiy that provided access token is not blacklist""" + try: + await validate_jwt(blacklist_signatures_storage, access_token) + data={"access": True} + except UnauthorizedError: + data = {"access": False} + return VerifyTokenResponse(data=data) @staticmethod async def terminate_all_sessions( @@ -131,19 +141,62 @@ async def terminate_all_sessions( await validate_jwt(blacklist_signatures_storage, access_token) response.delete_cookie(settings.sessions_cookie_name) - token_payload = decode_jwt(access_token) + try: + token_payload = decode_jwt(access_token) + except JWTError: + raise InvalidTokenError() + token_headers = jwt.get_unverified_header(access_token) - username = token_payload.get("username") user_id = token_payload.get("user_id") old_user_signature = token_headers.get("jti") - new_user_signature = generate_user_signature(username) await blacklist_signatures_storage.create(user_id, old_user_signature) - await __class__._update_user_signature(db_session, old_user_signature, new_user_signature) + await __class__._update_user_signature(db_session, user_id, old_user_signature) await refresh_token_storage.delete_all(db_session, user_id=user_id) + @staticmethod + async def refresh_tokens( + db_session: AsyncSession, + refresh_token_storage: RefreshTokensStorage, + response: Response, + refresh_token: str + ): + """Regenerate JWT pair of tokens""" + + try: + decode_jwt(refresh_token) + except JWTError: + raise InvalidTokenError() + + refresh_token_headers = jwt.get_unverified_header(refresh_token) + refresh_jti = refresh_token_headers.get("jti") + statement = select(users_models.UserRefreshTokens).where( + users_models.UserRefreshTokens.token == refresh_jti + ) + result = await db_session.execute(statement) + if (token := result.scalar()) is None: + raise TokenNotFoundError() + + user = await UserService.get_by_id(db_session, token.user_id) + await refresh_token_storage.delete(db_session, refresh_token) + + jwt_payload = JWTPayload(user_id=user.id).model_dump(mode="json") + tokens = generate_jwt( + payload=jwt_payload, + access_jti=user.signature.signature, + refresh_jti=str(uuid.uuid4()), + ) + await refresh_token_storage.create(db_session, tokens.refresh_token, user.id) + await __class__._set_user_cookie( + settings.sessions_cookie_name, + tokens.access_token, + response + ) + return tokens + + @staticmethod async def _save_login_session_if_not_exists( db_session: AsyncSession, user: users_models.User, request: Request @@ -193,14 +246,17 @@ async def _set_user_cookie(cookie_key: str, cookie_value: str, response: Respons @staticmethod async def _update_user_signature( - db_session: AsyncSession, - old_user_signature: str, - new_user_signature: str + db_session: AsyncSession, + user_id: UUID4, + old_user_signature: str ): """Update user signature when user terminate all sessions.""" - statement = update(users_models.UserSignature).where( - users_models.UserSignature.signature == old_user_signature - ).values({users_models.UserSignature.signature: new_user_signature}) + user = await UserService.get_by_id(db_session, user_id) + new_user_signature = generate_user_signature(user.username) + + statement = update(users_models.UserSignature)\ + .where(users_models.UserSignature.signature == old_user_signature)\ + .values({users_models.UserSignature.signature: new_user_signature}) await db_session.execute(statement) try: await db_session.commit() @@ -208,4 +264,5 @@ async def _update_user_signature( await db_session.rollback() raise ServiceError() + AuthService = JWTAuthService() diff --git a/src/v1/users/service.py b/src/v1/users/service.py index 821518c..f7174a2 100755 --- a/src/v1/users/service.py +++ b/src/v1/users/service.py @@ -22,7 +22,7 @@ class UserService: async def get_by_email(db_session: AsyncSession, email: EmailStr) -> Type[User]: statement = select(User).where(User.email == email) result = await db_session.execute(statement) - if (exists_user := result.scalar_one()) is None: + if (exists_user := result.scalar()) is None: raise UserNotFoundError() return exists_user From 567889d5d5a8b4e47c6a78181decbe0ece91d3a2 Mon Sep 17 00:00:00 2001 From: Alexandr Jariuc Date: Tue, 28 Nov 2023 05:50:57 +1100 Subject: [PATCH 3/3] fix: add routers --- src/v1/auth/routers.py | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/v1/auth/routers.py b/src/v1/auth/routers.py index 387ce01..9a18bae 100755 --- a/src/v1/auth/routers.py +++ b/src/v1/auth/routers.py @@ -4,7 +4,8 @@ from src.db.postgres import DatabaseSession, RefreshTokensStorage from src.db.redis import BlacklistSignatureStorage from src.v1.auth.schemas import (TokensResponse, UserCreate, UserLogin, - UserResponse, LogoutResponse, UserLogout) + UserResponse, LogoutResponse, UserLogout, VerifyTokenResponse, + RefreshTokens) from src.v1.auth.service import AuthService from src.core.config import settings @@ -37,6 +38,40 @@ async def signin( return TokensResponse(data=tokens) +@router.post( + "/verify", + summary="Верификация переданного access_token", + response_model=VerifyTokenResponse +) +async def verify_token( + blacklist_signatures_storage: BlacklistSignatureStorage, + access_token: str | None = Depends(cookie_scheme) +) -> VerifyTokenResponse: + """Выход из всех сессий пользователя""" + return await AuthService.verify(access_token, blacklist_signatures_storage) + + +@router.post( + "/refresh", + summary="Выдача новой пары JWT-токенов", + response_model=TokensResponse +) +async def refresh_tokens( + db_session: DatabaseSession, + refresh_token_storage: RefreshTokensStorage, + response: Response, + data: RefreshTokens, +) -> TokensResponse: + """Выход из всех сессий пользователя""" + tokens = await AuthService.refresh_tokens( + db_session, + refresh_token_storage, + response, + data.refresh_token + ) + return TokensResponse(data=tokens) + + @router.post("/logout", summary="Выход из текущей сессии", response_model=LogoutResponse) async def logout( db_session: DatabaseSession, @@ -66,7 +101,7 @@ async def terminate_all_sessions( refresh_token_storage: RefreshTokensStorage, response: Response, access_token: str | None = Depends(cookie_scheme) -) -> TokensResponse: +) -> LogoutResponse: """Выход из всех сессий пользователя""" await AuthService.terminate_all_sessions( db_session,