Skip to content

Commit

Permalink
Progress information for background tasks (#493)
Browse files Browse the repository at this point in the history
* Add progress to search index

* Remove test due to issues with celery/pytest

* Do not fail on unserializable results
  • Loading branch information
DavidMStraub authored Feb 24, 2024
1 parent 420b2a5 commit 7da26ec
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 80 deletions.
9 changes: 9 additions & 0 deletions gramps_webapi/api/resources/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,17 @@ def serialize_or_str(obj):
except TypeError:
return str(obj)

def serializable_or_str(obj):
try:
json.dumps(obj)
return obj
except TypeError:
return str(obj)

return {
"state": task.state,
"result_object": serializable_or_str(task.result),
# kept for backward compatibility
"info": serialize_or_str(task.info),
"result": serialize_or_str(task.result),
}
47 changes: 38 additions & 9 deletions gramps_webapi/api/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@

"""Full-text search utilities."""

from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple

from flask import current_app
from gramps.gen.db.base import DbReadBase
from gramps.gen.lib import Name, Place
from gramps.gen.lib.primaryobj import BasicPrimaryObject as GrampsObject
from unidecode import unidecode
from whoosh import index
from whoosh.fields import BOOLEAN, DATETIME, ID, TEXT, Schema
from whoosh.qparser import FieldsPlugin, MultifieldParser, QueryParser
from whoosh.qparser import MultifieldParser, QueryParser
from whoosh.qparser.dateparse import DateParserPlugin
from whoosh.query import Term
from whoosh.searching import Hit
from whoosh.sorting import FieldFacet
from whoosh.writing import AsyncWriter

from ..const import PRIMARY_GRAMPS_OBJECTS
from ..const import GRAMPS_OBJECT_PLURAL, PRIMARY_GRAMPS_OBJECTS
from ..types import FilenameOrPath


Expand Down Expand Up @@ -103,6 +103,13 @@ def obj_strings_from_handle(
"""Return object strings from a handle and Gramps class name."""
query_method = db_handle.method("get_%s_from_handle", class_name)
obj = query_method(handle)
return obj_strings_from_object(db_handle=db_handle, class_name=class_name, obj=obj)


def obj_strings_from_object(
db_handle: DbReadBase, class_name: str, obj: GrampsObject
) -> Optional[Dict[str, Any]]:
"""Return object strings from a handle and Gramps class name."""
obj_string, obj_string_private = object_to_strings(obj)
private = hasattr(obj, "private") and obj.private
if obj_string:
Expand All @@ -122,9 +129,10 @@ def iter_obj_strings(
) -> Generator[Dict[str, Any], None, None]:
"""Iterate over object strings in the whole database."""
for class_name in PRIMARY_GRAMPS_OBJECTS:
iter_method = db_handle.method("iter_%s_handles", class_name)
for handle in iter_method():
obj_strings = obj_strings_from_handle(db_handle, class_name, handle)
plural_name = GRAMPS_OBJECT_PLURAL[class_name]
iter_method = db_handle.method("iter_%s", plural_name)
for obj in iter_method():
obj_strings = obj_strings_from_object(db_handle, class_name, obj)
if obj_strings:
yield obj_strings

Expand Down Expand Up @@ -198,10 +206,31 @@ def _add_obj_strings(self, writer, obj_dict):
"Failed adding object {}".format(obj_dict["handle"])
)

def reindex_full(self, db_handle: DbReadBase):
def _get_total_number_of_objects(self, db_handle):
"""Get the total number of searchable objects in the database."""
return (
db_handle.get_number_of_people()
+ db_handle.get_number_of_families()
+ db_handle.get_number_of_sources()
+ db_handle.get_number_of_citations()
+ db_handle.get_number_of_events()
+ db_handle.get_number_of_media()
+ db_handle.get_number_of_places()
+ db_handle.get_number_of_repositories()
+ db_handle.get_number_of_notes()
+ db_handle.get_number_of_tags()
)

def reindex_full(
self, db_handle: DbReadBase, progress_cb: Optional[Callable] = None
):
"""Reindex the whole database."""
if progress_cb:
total = self._get_total_number_of_objects(db_handle)
with self.index(overwrite=True).writer() as writer:
for obj_dict in iter_obj_strings(db_handle):
for i, obj_dict in enumerate(iter_obj_strings(db_handle)):
if progress_cb:
progress_cb(current=i, total=total)
self._add_obj_strings(writer, obj_dict)

def _get_object_timestamps(self):
Expand Down
28 changes: 23 additions & 5 deletions gramps_webapi/api/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def make_task_response(task: AsyncResult):
return payload, HTTPStatus.ACCEPTED


def clip_progress(x: float) -> float:
"""Clip the progress to [0, 1), else return -1."""
if x < 0 or x >= 1:
return -1
return x


@shared_task()
def send_email_reset_password(email: str, token: str):
"""Send an email for password reset."""
Expand Down Expand Up @@ -94,20 +101,31 @@ def send_email_new_user(
send_email(subject=subject, body=body, to=emails)


def _search_reindex_full(tree) -> None:
def _search_reindex_full(tree, progress_cb: Optional[Callable] = None) -> None:
"""Rebuild the search index."""
indexer = get_search_indexer(tree)
db = get_db_outside_request(tree=tree, view_private=True, readonly=True)
try:
indexer.reindex_full(db)
indexer.reindex_full(db, progress_cb=progress_cb)
finally:
db.close()


@shared_task()
def search_reindex_full(tree) -> None:
@shared_task(bind=True)
def search_reindex_full(self, tree) -> None:
"""Rebuild the search index."""
return _search_reindex_full(tree)

def progress_cb(current: int, total: int):
self.update_state(
state="PROGRESS",
meta={
"current": current,
"total": total,
"progress": clip_progress(current / total),
},
)

return _search_reindex_full(tree, progress_cb=progress_cb)


def _search_reindex_incremental(tree) -> None:
Expand Down
66 changes: 0 additions & 66 deletions tests/test_endpoints/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,69 +1010,3 @@ def test_create_owner(self):
assert get_number_users() == 1
rv = self.client.get(f"{BASE_URL}/token/create_owner/")
assert rv.status_code == 405


@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
class TestUserCelery(unittest.TestCase):
"""Test cases for the /api/user endpoints, using celery."""

def setUp(self):
self.name = "Test Web API"
self.dbman = CLIDbManager(DbState())
dbpath, _name = self.dbman.create_new_db_cli(self.name, dbid="sqlite")
self.tree = os.path.basename(dbpath)
with patch.dict("os.environ", {ENV_CONFIG_FILE: TEST_AUTH_CONFIG}):
self.app = create_app(
config={
"TESTING": True,
"RATELIMIT_ENABLED": False,
"CELERY_CONFIG": {
"broker_url": "redis://",
"result_backend": "redis://",
},
}
)
self.client = self.app.test_client()
with self.app.app_context():
user_db.create_all()
add_user(
name="user",
password="123",
email="[email protected]",
role=ROLE_MEMBER,
tree=self.tree,
)
add_user(
name="owner",
password="123",
email="[email protected]",
role=ROLE_OWNER,
tree=self.tree,
)
self.assertTrue(self.app.testing)
self.ctx = self.app.test_request_context()
self.ctx.push()

def tearDown(self):
self.ctx.pop()
self.dbman.remove_database(self.name)

def test_reset_password_celery(self):
rv = self.client.post(BASE_URL + "/users/user/password/reset/trigger/")
assert rv.status_code == 202
assert "task" in rv.json
url = rv.json["task"]["href"]
rv = self.client.post(
BASE_URL + "/token/", json={"username": "user", "password": "123"}
)
assert rv.status_code == 200
token = rv.json["access_token"]
from time import sleep

sleep(5)
rv = self.client.get(
url,
headers={"Authorization": f"Bearer {token}"},
)
assert rv.status_code == 200

0 comments on commit 7da26ec

Please sign in to comment.