From 392de43036cb5ef312e24a8db5eccfbb60a1eff3 Mon Sep 17 00:00:00 2001 From: Bakar Tavadze Date: Wed, 20 Mar 2024 20:09:23 +0400 Subject: [PATCH] Update the get_assistant storage method. --- backend/app/api/runs.py | 9 ++------- backend/app/storage.py | 11 +---------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index 7e117807..8ec4b813 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -1,4 +1,3 @@ -import asyncio import json from typing import Optional, Sequence @@ -16,7 +15,7 @@ from app.agent import agent from app.schema import OpengptsUserId -from app.storage import get_assistant, get_public_assistant +from app.storage import get_assistant from app.stream import astream_messages, to_sse router = APIRouter() @@ -35,11 +34,7 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser body = await request.json() except json.JSONDecodeError: raise RequestValidationError(errors=["Invalid JSON body"]) - assistant, public_assistant = await asyncio.gather( - get_assistant(opengpts_user_id, body["assistant_id"]), - get_public_assistant(body["assistant_id"]), - ) - assistant = assistant or public_assistant + assistant = await get_assistant(opengpts_user_id, body["assistant_id"]) if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") config: RunnableConfig = { diff --git a/backend/app/storage.py b/backend/app/storage.py index 4a08132a..ac97045c 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -22,21 +22,12 @@ async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]: """Get an assistant by ID.""" async with get_pg_pool().acquire() as conn: return await conn.fetchrow( - "SELECT * FROM assistant WHERE assistant_id = $1 AND user_id = $2", + "SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public = true)", assistant_id, user_id, ) -async def get_public_assistant(assistant_id: str) -> Optional[Assistant]: - """Get a public assistant by ID.""" - async with get_pg_pool().acquire() as conn: - return await conn.fetchrow( - "SELECT * FROM assistant WHERE assistant_id = $1 AND public = true", - assistant_id, - ) - - async def list_public_assistants(assistant_ids: Sequence[str]) -> List[Assistant]: """List all the public assistants.""" async with get_pg_pool().acquire() as conn: