Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhanced the get product endpoint with category filter #1222

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
10 changes: 7 additions & 3 deletions api/utils/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, Query
from api.db.database import Base

from api.utils.success_response import success_response
Expand All @@ -12,6 +12,7 @@ def paginated_response(
skip: int,
limit: int,
join: Optional[Any] = None,
query: Optional[Query] = None,
filters: Optional[Dict[str, Any]]=None
):

Expand All @@ -24,6 +25,7 @@ def paginated_response(
* skip- this is the number of items to skip before fetching the next page of data. This would also
be a query parameter
* join- this is an optional argument to join a table to the query
* query- this is an optional custom query to use instead of querying all items from the model.
* filters- this is an optional dictionary of filters to apply to the query

Example use:
Expand Down Expand Up @@ -61,7 +63,8 @@ def paginated_response(
```
'''

query = db.query(model)
if query is None:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you set this condition what happens is query is not None?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohkay, your test fails

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If query is none, then it gets all products;
if query is not none, it gets the product related to that category.

query = db.query(model)

if join is not None:
query = query.join(join)
Expand All @@ -82,7 +85,8 @@ def paginated_response(

total = query.count()
results = jsonable_encoder(query.offset(skip).limit(limit).all())
total_pages = int(total / limit) + (total % limit > 0)
# total_pages = int(total / limit) + (total % limit > 0)
total_pages = (total + limit - 1) // limit

return success_response(
status_code=200,
Expand Down
17 changes: 14 additions & 3 deletions api/v1/routes/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from api.utils.pagination import paginated_response
from api.utils.success_response import success_response
from api.db.database import get_db
from api.v1.models.product import Product, ProductFilterStatusEnum, ProductStatusEnum
from api.v1.models.product import Product, ProductCategory, ProductFilterStatusEnum, ProductStatusEnum
from api.v1.services.product import product_service, ProductCategoryService
from api.v1.schemas.product import (
ProductCategoryCreate,
Expand Down Expand Up @@ -37,11 +37,22 @@ async def get_all_products(
ge=1, description="Number of products per page")] = 10,
skip: Annotated[int, Query(
ge=1, description="Page number (starts from 1)")] = 0,
category: Annotated[Optional[str], Query(
description="Filter products by category name")] = None,
db: Session = Depends(get_db),
):
"""Endpoint to get all products. Only accessible to superadmin"""
"""
Endpoint to get all products. Only accessible to superadmin.
Optionally filter products by category.
"""
# Base query
query = db.query(Product)

# Apply category filter if provided
if category:
query = query.join(Product.category).filter(ProductCategory.name.ilike(f"%{category}%"))

return paginated_response(db=db, model=Product, limit=limit, skip=skip)
return paginated_response(db=db, model=Product, limit=limit, skip=skip, query=query)


# categories
Expand Down
209 changes: 209 additions & 0 deletions tests/v1/product/test_get_product_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from unittest.mock import MagicMock
from uuid_extensions import uuid7
from datetime import datetime, timezone, timedelta

from api.v1.models.organisation import Organisation
from api.v1.models.product import Product, ProductCategory
from api.v1.models.user import User
from main import app
from api.v1.routes.blog import get_db
from api.v1.services.user import user_service


# Mock database dependency
@pytest.fixture
def db_session_mock():
db_session = MagicMock(spec=Session)
return db_session


@pytest.fixture
def client(db_session_mock):
app.dependency_overrides[get_db] = lambda: db_session_mock
client = TestClient(app)
yield client
app.dependency_overrides = {}


# Mock user service dependency

user_id = uuid7()
org_id = uuid7()
product_id = uuid7()
category_id = uuid7()
timezone_offset = -8.0
tzinfo = timezone(timedelta(hours=timezone_offset))
timeinfo = datetime.now(tzinfo)
created_at = timeinfo
updated_at = timeinfo
access_token = user_service.create_access_token(str(user_id))
access_token2 = user_service.create_access_token(str(uuid7()))

# Create test user

user = User(
id=str(user_id),
email="[email protected]",
password="password123",
created_at=created_at,
updated_at=updated_at,
)

# Create test organisation

org = Organisation(
id=str(org_id),
name="hng",
email=None,
industry=None,
type=None,
country=None,
state=None,
address=None,
description=None,
created_at=created_at,
updated_at=updated_at,
)

# Create test category

category = ProductCategory(id=category_id, name="Electronics")

# Create test product

product = Product(
id=str(product_id),
name="prod one",
description="Test product",
price=125.55,
org_id=str(org_id),
quantity=50,
image_url="http://img",
category_id=str(category_id),
status="in_stock",
archived=False,
)


# Mock data for multiple products
products = [
Product(
id=str(uuid7()),
name="Smartphone",
description="A smartphone",
price=500.00,
org_id=str(org_id),
quantity=10,
image_url="http://img1",
category_id=str(category_id),
status="in_stock",
archived=False,
),
Product(
id=str(uuid7()),
name="Laptop",
description="A laptop",
price=1200.00,
org_id=str(org_id),
quantity=5,
image_url="http://img2",
category_id=str(category_id),
status="in_stock",
archived=False,
),
Product(
id=str(uuid7()),
name="T-Shirt",
description="A T-Shirt",
price=20.00,
org_id=str(org_id),
quantity=100,
image_url="http://img3",
category_id=str(uuid7()), # Different category
status="in_stock",
archived=False,
),
]


def test_get_products_filtered_by_category(client, db_session_mock):
# Mock the database query to return filtered products
db_session_mock.query().join().filter().offset().limit().all.return_value = [
products[0], products[1]]
db_session_mock.query().join().filter().count.return_value = 2 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=Electronics",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]["items"]) == 2


def test_get_all_products_without_filter(client, db_session_mock):
# Mock the database query to return all products
db_session_mock.query().offset().limit().all.return_value = products
db_session_mock.query().count.return_value = 3

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]["items"]) == 3


def test_unauthorized_access(client, db_session_mock):
# Test unauthorized access (missing or invalid token)
response = client.get("/api/v1/products")
assert response.status_code == 401
assert response.json() == {
"status": False,
"status_code": 401,
"message": "Not authenticated"
}


def test_invalid_category_name(client, db_session_mock):
# Mock the database query to return no products for an invalid category
db_session_mock.query().join().filter().offset().limit().all.return_value = []
db_session_mock.query().join().filter().count.return_value = 0 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=InvalidCategory",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]["items"]) == 0


def test_empty_results_for_valid_category(client, db_session_mock):
# Mock the database query to return no products for a valid but unused category
db_session_mock.query().join().filter().offset().limit().all.return_value = []
db_session_mock.query().join().filter().count.return_value = 0 # Return an integer

headers = {"authorization": f"Bearer {access_token}"}
response = client.get(
"/api/v1/products?category=Furniture",
headers=headers
)

assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]["items"]) == 0
Loading