Skip to content

Commit

Permalink
Update the get_assistant storage method.
Browse files Browse the repository at this point in the history
  • Loading branch information
bakar-io committed Mar 20, 2024
1 parent b17c425 commit 392de43
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 17 deletions.
9 changes: 2 additions & 7 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import json
from typing import Optional, Sequence

Expand All @@ -16,7 +15,7 @@

from app.agent import agent
from app.schema import OpengptsUserId
from app.storage import get_assistant, get_public_assistant
from app.storage import get_assistant
from app.stream import astream_messages, to_sse

router = APIRouter()
Expand All @@ -35,11 +34,7 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser
body = await request.json()
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
assistant, public_assistant = await asyncio.gather(
get_assistant(opengpts_user_id, body["assistant_id"]),
get_public_assistant(body["assistant_id"]),
)
assistant = assistant or public_assistant
assistant = await get_assistant(opengpts_user_id, body["assistant_id"])
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
config: RunnableConfig = {
Expand Down
11 changes: 1 addition & 10 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,12 @@ async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]:
"""Get an assistant by ID."""
async with get_pg_pool().acquire() as conn:
return await conn.fetchrow(
"SELECT * FROM assistant WHERE assistant_id = $1 AND user_id = $2",
"SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public = true)",
assistant_id,
user_id,
)


async def get_public_assistant(assistant_id: str) -> Optional[Assistant]:
"""Get a public assistant by ID."""
async with get_pg_pool().acquire() as conn:
return await conn.fetchrow(
"SELECT * FROM assistant WHERE assistant_id = $1 AND public = true",
assistant_id,
)


async def list_public_assistants(assistant_ids: Sequence[str]) -> List[Assistant]:
"""List all the public assistants."""
async with get_pg_pool().acquire() as conn:
Expand Down

0 comments on commit 392de43

Please sign in to comment.