From a081dce7d52098c89733cb429a5c3e544d5cfd95 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 11 Feb 2025 13:46:27 +0530 Subject: [PATCH 1/3] refactor + fixes to chat bubble ui + lint fixes + cleanup (#3437) --- packages/agent/src/helper.ts | 89 +++++++++ packages/agent/src/server.ts | 96 ++-------- packages/client/src/components/chat.tsx | 171 +++++++++--------- .../src/components/ui/chat/chat-bubble.tsx | 4 +- packages/plugin-local-ai/src/index.ts | 28 +-- 5 files changed, 204 insertions(+), 184 deletions(-) create mode 100644 packages/agent/src/helper.ts diff --git a/packages/agent/src/helper.ts b/packages/agent/src/helper.ts new file mode 100644 index 00000000000..e0120f7ce8b --- /dev/null +++ b/packages/agent/src/helper.ts @@ -0,0 +1,89 @@ +import { messageCompletionFooter } from "@elizaos/core"; +import path from "node:path"; +import multer from "multer"; +import fs from "node:fs"; + + + +export const messageHandlerTemplate = + // {{goals}} + // "# Action Examples" is already included + `{{actionExamples}} +(Action examples are for reference only. Do not use the information from them in your response.) + +# Knowledge +{{knowledge}} + +# Task: Generate dialog and actions for the character {{agentName}}. +About {{agentName}}: +{{bio}} +{{lore}} + +{{providers}} + +{{attachments}} + +# Capabilities +Note that {{agentName}} is capable of reading/seeing/hearing various forms of media, including images, videos, audio, plaintext and PDFs. Recent attachments have been included above under the "Attachments" section. + +{{messageDirections}} + +{{recentMessages}} + +{{actions}} + +# Instructions: Write the next message for {{agentName}}. +${messageCompletionFooter}`; + +export const hyperfiHandlerTemplate = `{{actionExamples}} +(Action examples are for reference only. Do not use the information from them in your response.) + +# Knowledge +{{knowledge}} + +# Task: Generate dialog and actions for the character {{agentName}}. +About {{agentName}}: +{{bio}} +{{lore}} + +{{providers}} + +{{attachments}} + +# Capabilities +Note that {{agentName}} is capable of reading/seeing/hearing various forms of media, including images, videos, audio, plaintext and PDFs. Recent attachments have been included above under the "Attachments" section. + +{{messageDirections}} + +{{recentMessages}} + +{{actions}} + +# Instructions: Write the next message for {{agentName}}. + +Response format should be formatted in a JSON block like this: +\`\`\`json +{ "lookAt": "{{nearby}}" or null, "emote": "{{emotes}}" or null, "say": "string" or null, "actions": (array of strings) or null } +\`\`\` +`; + + + + +export const storage = multer.diskStorage({ + destination: (req, file, cb) => { + const uploadDir = path.join(process.cwd(), "data", "uploads"); + // Create the directory if it doesn't exist + if (!fs.existsSync(uploadDir)) { + fs.mkdirSync(uploadDir, { recursive: true }); + } + cb(null, uploadDir); + }, + filename: (req, file, cb) => { + const uniqueSuffix = `${Date.now()}-${Math.round(Math.random() * 1e9)}`; + cb(null, `${uniqueSuffix}-${file.originalname}`); + }, +}); + +// some people have more memory than disk.io +export const upload = multer({ storage /*: multer.memoryStorage() */ }); diff --git a/packages/agent/src/server.ts b/packages/agent/src/server.ts index 302b9d9c378..827b9b56f8f 100644 --- a/packages/agent/src/server.ts +++ b/packages/agent/src/server.ts @@ -9,105 +9,31 @@ import { type Content, type Media, type Memory, - type IAgentRuntime + type IAgentRuntime, + type Character } from "@elizaos/core"; import bodyParser from "body-parser"; import cors from "cors"; import express, { type Request as ExpressRequest } from "express"; -import multer from "multer"; + import * as fs from "node:fs"; import * as path from "node:path"; import { z } from "zod"; import { createApiRouter } from "./api.ts"; import replyAction from "./reply.ts"; +import { messageHandlerTemplate } from "./helper.ts"; +import { upload } from "./helper.ts"; -const storage = multer.diskStorage({ - destination: (req, file, cb) => { - const uploadDir = path.join(process.cwd(), "data", "uploads"); - // Create the directory if it doesn't exist - if (!fs.existsSync(uploadDir)) { - fs.mkdirSync(uploadDir, { recursive: true }); - } - cb(null, uploadDir); - }, - filename: (req, file, cb) => { - const uniqueSuffix = `${Date.now()}-${Math.round(Math.random() * 1e9)}`; - cb(null, `${uniqueSuffix}-${file.originalname}`); - }, -}); - -// some people have more memory than disk.io -const upload = multer({ storage /*: multer.memoryStorage() */ }); - -export const messageHandlerTemplate = - // {{goals}} - // "# Action Examples" is already included - `{{actionExamples}} -(Action examples are for reference only. Do not use the information from them in your response.) - -# Knowledge -{{knowledge}} - -# Task: Generate dialog and actions for the character {{agentName}}. -About {{agentName}}: -{{bio}} -{{lore}} - -{{providers}} - -{{attachments}} - -# Capabilities -Note that {{agentName}} is capable of reading/seeing/hearing various forms of media, including images, videos, audio, plaintext and PDFs. Recent attachments have been included above under the "Attachments" section. - -{{messageDirections}} - -{{recentMessages}} - -{{actions}} - -# Instructions: Write the next message for {{agentName}}. -${messageCompletionFooter}`; - -export const hyperfiHandlerTemplate = `{{actionExamples}} -(Action examples are for reference only. Do not use the information from them in your response.) - -# Knowledge -{{knowledge}} - -# Task: Generate dialog and actions for the character {{agentName}}. -About {{agentName}}: -{{bio}} -{{lore}} - -{{providers}} - -{{attachments}} - -# Capabilities -Note that {{agentName}} is capable of reading/seeing/hearing various forms of media, including images, videos, audio, plaintext and PDFs. Recent attachments have been included above under the "Attachments" section. - -{{messageDirections}} - -{{recentMessages}} - -{{actions}} -# Instructions: Write the next message for {{agentName}}. -Response format should be formatted in a JSON block like this: -\`\`\`json -{ "lookAt": "{{nearby}}" or null, "emote": "{{emotes}}" or null, "say": "string" or null, "actions": (array of strings) or null } -\`\`\` -`; export class CharacterServer { public app: express.Application; private agents: Map; // container management private server: any; // Store server instance - public startAgent: Function; // Store startAgent functor - public loadCharacterTryPath: Function; // Store loadCharacterTryPath functor - public jsonToCharacter: Function; // Store jsonToCharacter functor + public startAgent: () => Promise; // Store startAgent function + public loadCharacterTryPath: (characterPath: string) => Promise; // Store loadCharacterTryPath function + public jsonToCharacter: (filePath: string, character: string | never) => Promise; // Store jsonToCharacter function constructor() { logger.log("DirectClient constructor"); @@ -177,7 +103,7 @@ export class CharacterServer { async (req: express.Request, res: express.Response) => { const agentId = req.params.agentId; const roomId = stringToUuid( - req.body.roomId ?? "default-room-" + agentId + req.body.roomId ?? `default-room-${agentId}` ); const userId = stringToUuid(req.body.userId ?? "user"); @@ -248,7 +174,7 @@ export class CharacterServer { }; const memory: Memory = { - id: stringToUuid(messageId + "-" + userId), + id: stringToUuid(`${messageId}-${userId}`), ...userMessage, agentId: runtime.agentId, userId, @@ -494,7 +420,7 @@ export class CharacterServer { if (hfOut.lookAt !== null || hfOut.emote !== null) { contentObj.text += ". Then I "; if (hfOut.lookAt !== null) { - contentObj.text += "looked at " + hfOut.lookAt; + contentObj.text += `looked at ${hfOut.lookAt}`; if (hfOut.emote !== null) { contentObj.text += " and "; } diff --git a/packages/client/src/components/chat.tsx b/packages/client/src/components/chat.tsx index f470063d152..675ff6560fd 100644 --- a/packages/client/src/components/chat.tsx +++ b/packages/client/src/components/chat.tsx @@ -17,7 +17,7 @@ import { useEffect, useRef, useState } from "react"; import AIWriter from "react-aiwriter"; import { AudioRecorder } from "./audio-recorder"; import CopyButton from "./copy-button"; -import { Avatar, AvatarImage } from "./ui/avatar"; +import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"; import { Badge } from "./ui/badge"; import ChatTtsButton from "./ui/chat/chat-tts-button"; import { useAutoScroll } from "./ui/chat/hooks/useAutoScroll"; @@ -31,6 +31,70 @@ type ExtraContentFields = { type ContentWithUser = Content & ExtraContentFields; +function MessageContent({ + message, + agentId, +}: { + message: ContentWithUser; + agentId: UUID; +}) { + return ( +
+ + {message.user === "user" ? message.text : {message.text}} + {/* Attachments */} +
+ {message.attachments?.map((attachment: IAttachment) => ( +
+ attachment +
+ + +
+
+ ))} +
+
+
+ {message.text && !message.isLoading ? ( +
+ + +
+ ) : null} +
+ {message.source ? ( + {message.source} + ) : null} + {message.action ? ( + {message.action} + ) : null} + {message.createdAt ? ( + + ) : null} +
+
+
+ ); +} + export default function Page({ agentId }: { agentId: UUID }) { const { toast } = useToast(); const [selectedFile, setSelectedFile] = useState(null); @@ -83,7 +147,7 @@ export default function Page({ agentId }: { agentId: UUID }) { const newMessages = [ { text: input, - user: "{{user1}}", + user: "user", createdAt: Date.now(), attachments, }, @@ -163,9 +227,9 @@ export default function Page({ agentId }: { agentId: UUID }) { disableAutoScroll={disableAutoScroll} > {messages.map((message: ContentWithUser) => { - const variant = getMessageVariant(message?.user); return (
- {message?.user !== "user" ? ( - - - - ) : null} -
- - {message?.user !== "user" ? ( - - {message?.text} - - ) : ( - message?.text - )} - {/* Attachments */} -
- {message?.attachments?.map( - (attachment: IAttachment) => ( -
- attachment -
- - -
-
- ) - )} -
-
-
- {message?.text && - !message?.isLoading ? ( -
- - -
- ) : null} -
- {message?.source ? ( - - {message.source} - - ) : null} - {message?.action ? ( - - {message.action} - - ) : null} - {message?.createdAt ? ( - - ) : null} -
-
-
+ {message.user !== "user" ? ( + <> + + + + + + ) : ( + <> + + + + + U + + + + )}
); diff --git a/packages/client/src/components/ui/chat/chat-bubble.tsx b/packages/client/src/components/ui/chat/chat-bubble.tsx index 30c31dbcffa..ac96bc95a19 100644 --- a/packages/client/src/components/ui/chat/chat-bubble.tsx +++ b/packages/client/src/components/ui/chat/chat-bubble.tsx @@ -76,8 +76,8 @@ const chatBubbleMessageVariants = cva("p-4", { variants: { variant: { received: - "bg-secondary text-secondary-foreground rounded-r-lg rounded-tl-lg", - sent: "bg-primary text-primary-foreground rounded-l-lg rounded-tr-lg", + "bg-secondary text-secondary-foreground rounded-lg rounded-bl-none", + sent: "bg-primary text-primary-foreground rounded-lg rounded-br-none" }, layout: { default: "", diff --git a/packages/plugin-local-ai/src/index.ts b/packages/plugin-local-ai/src/index.ts index 8b37c9f818d..a1810b7ad8a 100644 --- a/packages/plugin-local-ai/src/index.ts +++ b/packages/plugin-local-ai/src/index.ts @@ -10,10 +10,10 @@ import { RawImage, type Tensor, } from "@huggingface/transformers"; -import { exec } from "child_process"; +import { exec } from "node:child_process"; import * as Echogarden from "echogarden"; import { EmbeddingModel, FlagEmbedding } from "fastembed"; -import fs from "fs"; +import fs from "node:fs"; import { getLlama, type Llama, @@ -24,11 +24,11 @@ import { type LlamaModel } from "node-llama-cpp"; import { nodewhisper } from "nodejs-whisper"; -import os from "os"; -import path from "path"; -import { PassThrough, Readable } from "stream"; -import { fileURLToPath } from "url"; -import { promisify } from "util"; +import os from "node:os"; +import path from "node:path"; +import { PassThrough, Readable } from "node:stream"; +import { fileURLToPath } from "node:url"; +import { promisify } from "node:util"; import { z } from "zod"; const execAsync = promisify(exec); @@ -285,9 +285,11 @@ class LocalAIManager { if (!this.sequence) { throw new Error("LLaMA model not initialized"); } - const session = new LlamaChatSession({ contextSequence: this.sequence }); - const wordsToPunishTokens = wordsToPunish.flatMap((word) => this.model!.tokenize(word)); + if (!this.model) { + throw new Error("Model is not initialized"); + } + const wordsToPunishTokens = wordsToPunish.flatMap((word) => this.model.tokenize(word)); const repeatPenalty: LlamaChatSessionRepeatPenalty = { punishTokensFilter: () => wordsToPunishTokens, @@ -447,9 +449,11 @@ export const localAIPlugin: Plugin = { async init(config: Record) { try { const validatedConfig = await configSchema.parseAsync(config); - Object.entries(validatedConfig).forEach(([key, value]) => { - if (value) process.env[key] = value; - }); + for (const [key, value] of Object.entries(validatedConfig)) { + if (value) { + process.env[key] = value; + } + } await localAIManager.initialize(); } catch (error) { From 3251fdc26ff218c5db6e08eaa6c7fc0fadc91af0 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 11 Feb 2025 15:35:51 +0530 Subject: [PATCH 2/3] chore: db path resolution (#3438) * db path resolution * Update resolve-database-path.ts --- packages/cli/src/commands/character.ts | 129 ++++++++---------- .../cli/src/utils/resolve-database-path.ts | 81 +++++++++++ 2 files changed, 138 insertions(+), 72 deletions(-) create mode 100644 packages/cli/src/utils/resolve-database-path.ts diff --git a/packages/cli/src/commands/character.ts b/packages/cli/src/commands/character.ts index 1e336e2c081..81d02cf7692 100644 --- a/packages/cli/src/commands/character.ts +++ b/packages/cli/src/commands/character.ts @@ -1,16 +1,16 @@ // src/commands/agent.ts -import { MessageExampleSchema } from "@elizaos/core" -import type { MessageExample } from "@elizaos/core"; -import { Command } from "commander" -import prompts from "prompts" -import { logger } from "../utils/logger" -import { z } from "zod" -import { getConfig } from "../utils/get-config" -import { handleError } from "../utils/handle-error" -import { Database, SqliteDatabaseAdapter } from "@elizaos-plugins/sqlite" -import { promises as fs } from "node:fs" +import { Database, SqliteDatabaseAdapter } from "@elizaos-plugins/sqlite"; +import type { MessageExample, UUID } from "@elizaos/core"; +import { MessageExampleSchema } from "@elizaos/core"; +import { Command } from "commander"; +import fs from "node:fs"; +import prompts from "prompts"; import { v4 as uuid } from "uuid"; -import type { UUID } from "@elizaos/core"; +import { z } from "zod"; +import { getConfig } from "../utils/get-config"; +import { handleError } from "../utils/handle-error"; +import { logger } from "../utils/logger"; +import { resolveDatabasePath } from "../utils/resolve-database-path"; const characterSchema = z.object({ id: z.string().uuid(), @@ -46,9 +46,10 @@ async function collectCharacterData( let currentStep = 0; const steps = ['name', 'bio', 'lore', 'adjectives', 'postExamples', 'messageExamples']; + let response: { value?: string }; + while (currentStep < steps.length) { const field = steps[currentStep]; - let response; switch (field) { case 'name': @@ -140,31 +141,26 @@ character .description("list all characters") .action(async () => { try { - const cwd = process.cwd() - const config = await getConfig(cwd) - if (!config) { - logger.error("No project.json found. Please run init first.") - process.exit(1) - } - - const db = new Database((config.database.config as { path: string }).path) - const adapter = new SqliteDatabaseAdapter(db) - await adapter.init() + const dbPath = await resolveDatabasePath({ requiredConfig: false }); + + const db = new Database(dbPath); + const adapter = new SqliteDatabaseAdapter(db); + await adapter.init(); - const characters = await adapter.listCharacters() + const characters = await adapter.listCharacters(); if (characters.length === 0) { - logger.info("No characters found") + logger.info("No characters found"); } else { - logger.info("\nCharacters:") + logger.info("\nCharacters:"); for (const character of characters) { - logger.info(` ${character.name} (${character.id})`) + logger.info(` ${character.name} (${character.id})`); } } - await adapter.close() + await adapter.close(); } catch (error) { - handleError(error) + handleError(error); } }) @@ -173,14 +169,6 @@ character .description("create a new character") .action(async () => { try { - const cwd = process.cwd() - const config = await getConfig(cwd) - if (!config) { - logger.error("No project.json found. Please run init first.") - process.exit(1) - } - - logger.info("\nCreating new character (type 'back' or 'forward' to navigate)") const formData = await collectCharacterData() if (!formData) { @@ -188,7 +176,8 @@ character return } - const db = new Database((config.database.config as { path: string }).path) + const dbPath = await resolveDatabasePath({ requiredConfig: true }); + const db = new Database(dbPath); const adapter = new SqliteDatabaseAdapter(db) await adapter.init() @@ -213,8 +202,8 @@ character const characterToCreate = { ...characterData, - messageExamples: characterData.messageExamples.map( - (msgArr: any) => msgArr.map((msg: any) => ({ + messageExamples: (characterData.messageExamples as MessageExample[][]).map( + (msgArr: MessageExample[]): MessageExample[] => msgArr.map((msg: MessageExample) => ({ user: msg.user || "unknown", content: msg.content })) @@ -252,21 +241,15 @@ character .argument("", "character ID") .action(async (characterId) => { try { - const cwd = process.cwd() - const config = await getConfig(cwd) - if (!config) { - logger.error("No project.json found. Please run init first.") - process.exit(1) - } + const dbPath = await resolveDatabasePath({ requiredConfig: true }); + const db = new Database(dbPath); + const adapter = new SqliteDatabaseAdapter(db); + await adapter.init(); - const db = new Database((config.database.config as { path: string }).path) - const adapter = new SqliteDatabaseAdapter(db) - await adapter.init() - - const existingCharacter = await adapter.getCharacter(characterId) + const existingCharacter = await adapter.getCharacter(characterId); if (!existingCharacter) { - logger.error(`Character ${characterId} not found`) - process.exit(1) + logger.error(`Character ${characterId} not found`); + process.exit(1); } logger.info(`\nEditing character ${existingCharacter.name} (type 'back' or 'forward' to navigate)`) @@ -277,8 +260,8 @@ character lore: existingCharacter.lore || [], adjectives: existingCharacter.adjectives || [], postExamples: existingCharacter.postExamples || [], - messageExamples: (existingCharacter.messageExamples || []).map( - (msgArr: any) => msgArr.map((msg: any) => ({ + messageExamples: (existingCharacter.messageExamples || [] as MessageExample[][]).map( + (msgArr: MessageExample[]): MessageExample[] => msgArr.map((msg: MessageExample) => ({ user: msg.user ?? "unknown", content: msg.content })) @@ -320,19 +303,27 @@ character .command("import") .description("import a character from file") .argument("", "JSON file path") - .action(async (file) => { + .action(async (fileArg) => { try { - const cwd = process.cwd() - const config = await getConfig(cwd) - if (!config) { - logger.error("No project.json found. Please run init first.") - process.exit(1) + // Use the provided argument if available; otherwise, prompt the user. + const filePath: string = fileArg || (await prompts({ + type: "text", + name: "file", + message: "Enter the path to the Character JSON file", + })).file; + + if (!filePath) { + logger.info("Import cancelled") + return } - - const characterData = JSON.parse(await fs.readFile(file, "utf8")) + + const characterData = JSON.parse(await fs.promises.readFile(filePath, "utf8")) const character = characterSchema.parse(characterData) - const db = new Database((config.database.config as { path: string }).path) + // resolve database path + const dbPath = await resolveDatabasePath({ requiredConfig: true }) + + const db = new Database(dbPath) const adapter = new SqliteDatabaseAdapter(db) await adapter.init() @@ -386,7 +377,7 @@ character } const outputPath = opts.output || `${character.name}.json` - await fs.writeFile(outputPath, JSON.stringify(character, null, 2)) + await fs.promises.writeFile(outputPath, JSON.stringify(character, null, 2)) logger.success(`Exported character to ${outputPath}`) await adapter.close() @@ -401,14 +392,8 @@ character .argument("", "character ID") .action(async (characterId) => { try { - const cwd = process.cwd() - const config = await getConfig(cwd) - if (!config) { - logger.error("No project.json found. Please run init first.") - process.exit(1) - } - - const db = new Database((config.database.config as { path: string }).path) + const dbPath = await resolveDatabasePath({ requiredConfig: true }) + const db = new Database(dbPath) const adapter = new SqliteDatabaseAdapter(db) await adapter.init() diff --git a/packages/cli/src/utils/resolve-database-path.ts b/packages/cli/src/utils/resolve-database-path.ts new file mode 100644 index 00000000000..4831b22f876 --- /dev/null +++ b/packages/cli/src/utils/resolve-database-path.ts @@ -0,0 +1,81 @@ +import { execaCommand } from "execa"; +import fs from "node:fs"; +import path from "node:path"; +import prompts from "prompts"; +import { getConfig } from "./get-config"; +import { logger } from "./logger"; + +// Helper function to search for db.sqlite using available shell commands. +async function searchDatabaseFile(): Promise { + const commands = ["find . -name db.sqlite", "dir /s /b db.sqlite"]; + for (const cmd of commands) { + try { + const { stdout } = await execaCommand(cmd); + const result = stdout.trim(); + if (result) { + return result; + } + } catch (error) { + logger.error(`Error executing command "${cmd}":`, error); + } + } + return ""; +} + +export async function resolveDatabasePath(options?: { requiredConfig?: boolean }): Promise { + const { requiredConfig = true } = options || {}; + const cwd = process.cwd(); + const config = await getConfig(cwd); + + // If a project config exists, use its database path. + if (config) { + return (config.database.config as { path: string }).path; + } + + // For commands that require an initialized project, exit early. + if (requiredConfig) { + logger.error("No project.json found. Please run init first."); + process.exit(1); + } + + // Otherwise, try to locate db.sqlite using shell commands. + let dbPath = await searchDatabaseFile(); + + // If path is resolved, log it. + if (dbPath) { + logger.info(`Resolved database path: ${dbPath}`); + } + + // If no database file was found, prompt the user to provide one. + if (!dbPath) { + logger.info("No db.sqlite found. Please provide a path to create a database."); + const dbInput = await prompts({ + type: "text", + name: "value", + message: "Enter path to create a database:", + }); + + // check if path was provided + if (!dbInput.value) { + logger.error("No path provided. Please provide a path to create a database."); + process.exit(1); + } + + // check if the path is valid + if (!fs.existsSync(dbInput.value)) { + logger.error("Invalid path. Please provide a valid path or directory."); + process.exit(1); + } + + // Use the input directly if it contains ".sqlite", otherwise append it. + dbPath = dbInput.value.includes(".sqlite") ? dbInput.value : path.join(dbInput.value, "db.sqlite"); + + // Create the file if it does not exist. + if (!fs.existsSync(dbPath)) { + await fs.promises.writeFile(dbPath, ""); + logger.success(`Created db.sqlite in ${dbInput.value}`); + } + } + + return dbPath; +} \ No newline at end of file From 08dd74ba65aa2d7113664f899cd29745876e4070 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 11 Feb 2025 21:55:36 +0530 Subject: [PATCH 3/3] add tests (#3445) --- .../__tests__/sqlite-adapter.test.ts | 337 +++++++++++++++++- 1 file changed, 331 insertions(+), 6 deletions(-) diff --git a/packages/plugin-sqlite/__tests__/sqlite-adapter.test.ts b/packages/plugin-sqlite/__tests__/sqlite-adapter.test.ts index 315228a31d7..e94a9607b35 100644 --- a/packages/plugin-sqlite/__tests__/sqlite-adapter.test.ts +++ b/packages/plugin-sqlite/__tests__/sqlite-adapter.test.ts @@ -1,14 +1,15 @@ -import type { UUID } from '@elizaos/core'; +import type { Account, Actor, Character, UUID } from '@elizaos/core'; +import { stringToUuid } from '@elizaos/core'; +import type { Database } from 'better-sqlite3'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { SqliteDatabaseAdapter } from '../src'; import { load } from '../src/sqlite_vec'; -import type { Database } from 'better-sqlite3'; // Mock the logger vi.mock('@elizaos/core', async () => { - const actual = await vi.importActual('@elizaos/core'); + const actual = await vi.importActual('@elizaos/core'); return { - ...actual as any, + ...actual, logger: { error: vi.fn() } @@ -20,9 +21,17 @@ vi.mock('../src/sqlite_vec', () => ({ load: vi.fn() })); + +interface MockDatabase { + prepare: ReturnType; + exec: ReturnType; + close: ReturnType; +} + describe('SqliteDatabaseAdapter', () => { let adapter: SqliteDatabaseAdapter; - let mockDb: any; + let mockDb: MockDatabase; + const testUuid = stringToUuid('test-character-id'); beforeEach(() => { // Create mock database methods @@ -38,7 +47,7 @@ describe('SqliteDatabaseAdapter', () => { }; // Initialize adapter with mock db - adapter = new SqliteDatabaseAdapter(mockDb as Database); + adapter = new SqliteDatabaseAdapter(mockDb as unknown as Database); }); afterEach(() => { @@ -168,4 +177,320 @@ describe('SqliteDatabaseAdapter', () => { expect(mockDb.close).toHaveBeenCalled(); }); }); + + describe('createAccount', () => { + it('should create an account successfully', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + const account: Account = { + id: 'test-id' as UUID, + name: 'Test User', + username: 'testuser', + email: 'test@example.com', + avatarUrl: 'https://example.com/avatar.png' + }; + + const result = await adapter.createAccount(account); + + expect(mockDb.prepare).toHaveBeenCalledWith( + 'INSERT INTO accounts (id, name, username, email, avatarUrl) VALUES (?, ?, ?, ?, ?)' + ); + expect(runMock).toHaveBeenCalledWith( + account.id, + account.name, + account.username, + account.email, + account.avatarUrl + ); + expect(result).toBe(true); + }); + + it('should handle errors when creating account', async () => { + mockDb.prepare.mockReturnValueOnce({ + run: vi.fn().mockImplementationOnce(() => { + throw new Error('Database error'); + }) + }); + + const account: Account = { + id: 'test-id' as UUID, + name: 'Test User', + username: 'testuser', + email: 'test@example.com', + avatarUrl: 'https://example.com/avatar.png' + }; + + const result = await adapter.createAccount(account); + expect(result).toBe(false); + }); + }); + + describe('getActorDetails', () => { + it('should return actor details', async () => { + const mockActors: Actor[] = [ + { id: 'actor-1' as UUID, name: 'Actor 1', username: 'actor1' }, + { id: 'actor-2' as UUID, name: 'Actor 2', username: 'actor2' } + ]; + + mockDb.prepare.mockReturnValueOnce({ + all: vi.fn().mockReturnValueOnce(mockActors) + }); + + const result = await adapter.getActorDetails({ roomId: 'room-1' as UUID }); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('SELECT a.id, a.name, a.username')); + expect(result).toEqual(mockActors); + }); + + it('should filter out null actors', async () => { + mockDb.prepare.mockReturnValueOnce({ + all: vi.fn().mockReturnValueOnce([null, { id: 'actor-1' as UUID, name: 'Actor 1', username: 'actor1' }, null]) + }); + + const result = await adapter.getActorDetails({ roomId: 'room-1' as UUID }); + + expect(result).toHaveLength(1); + expect(result[0]).toEqual({ id: 'actor-1', name: 'Actor 1', username: 'actor1' }); + }); + }); + + describe('getMemoryById', () => { + it('should return memory when it exists', async () => { + const mockMemory = { + id: 'memory-1' as UUID, + content: JSON.stringify({ text: 'Test memory' }) + }; + + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce(mockMemory), + bind: vi.fn() + }); + + const result = await adapter.getMemoryById('memory-1' as UUID); + + expect(mockDb.prepare).toHaveBeenCalledWith('SELECT * FROM memories WHERE id = ?'); + expect(result).toEqual({ + ...mockMemory, + content: { text: 'Test memory' } + }); + }); + + it('should return null when memory does not exist', async () => { + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce(undefined), + bind: vi.fn() + }); + + const result = await adapter.getMemoryById('non-existent' as UUID); + expect(result).toBeNull(); + }); + }); + + describe('Character operations', () => { + const mockCharacter: Required> = { + id: testUuid, + name: 'Test Character', + bio: 'Test Bio', + lore: ['Test lore'], + messageExamples: [[]], + postExamples: ['Test post'], + topics: ['Test topic'], + adjectives: ['Test adjective'], + style: { + all: ['Test style'], + chat: ['Test chat style'], + post: ['Test post style'] + } + }; + + it('should create a character', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + await adapter.createCharacter(mockCharacter); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('INSERT INTO characters')); + expect(runMock).toHaveBeenCalledWith( + mockCharacter.id, + mockCharacter.name, + mockCharacter.bio, + JSON.stringify(mockCharacter) + ); + }); + + it('should create a character with generated UUID', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + const characterWithoutId: Omit = { + name: mockCharacter.name, + bio: mockCharacter.bio, + lore: mockCharacter.lore, + messageExamples: mockCharacter.messageExamples, + postExamples: mockCharacter.postExamples, + topics: mockCharacter.topics, + adjectives: mockCharacter.adjectives, + style: mockCharacter.style + }; + + await adapter.createCharacter(characterWithoutId as Character); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('INSERT INTO characters')); + expect(runMock).toHaveBeenCalledWith( + expect.any(String), + characterWithoutId.name, + characterWithoutId.bio, + expect.any(String) + ); + }); + + it('should update a character', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + await adapter.updateCharacter(mockCharacter); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('UPDATE characters')); + expect(runMock).toHaveBeenCalledWith( + mockCharacter.name, + mockCharacter.bio, + JSON.stringify(mockCharacter), + mockCharacter.id + ); + }); + + it('should get a character', async () => { + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce(mockCharacter) + }); + + const result = await adapter.getCharacter(mockCharacter.id); + + expect(mockDb.prepare).toHaveBeenCalledWith('SELECT * FROM characters WHERE id = ?'); + expect(result).toEqual(mockCharacter); + }); + + it('should return null when getting non-existent character', async () => { + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce(null) + }); + + const result = await adapter.getCharacter(testUuid); + + expect(mockDb.prepare).toHaveBeenCalledWith('SELECT * FROM characters WHERE id = ?'); + expect(result).toBeNull(); + }); + + it('should remove a character', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + await adapter.removeCharacter(mockCharacter.id); + + expect(mockDb.prepare).toHaveBeenCalledWith('DELETE FROM characters WHERE id = ?'); + expect(runMock).toHaveBeenCalledWith(mockCharacter.id); + }); + + it('should list all characters', async () => { + const mockCharacters = [mockCharacter]; + mockDb.prepare.mockReturnValueOnce({ + all: vi.fn().mockReturnValueOnce(mockCharacters) + }); + + const result = await adapter.listCharacters(); + + expect(mockDb.prepare).toHaveBeenCalledWith('SELECT * FROM characters'); + expect(result).toEqual(mockCharacters); + }); + + it('should import a character', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + await adapter.importCharacter(mockCharacter); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('INSERT INTO characters')); + expect(runMock).toHaveBeenCalledWith( + mockCharacter.id, + mockCharacter.name, + mockCharacter.bio, + JSON.stringify(mockCharacter) + ); + }); + + it('should export a character', async () => { + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce(mockCharacter) + }); + + const result = await adapter.exportCharacter(mockCharacter.id); + + expect(mockDb.prepare).toHaveBeenCalledWith('SELECT * FROM characters WHERE id = ?'); + expect(result).toEqual(mockCharacter); + }); + }); + + describe('Cache operations', () => { + const mockParams = { + key: 'test-key', + agentId: 'agent-1' as UUID, + value: 'test-value' + }; + + it('should set cache value', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + const result = await adapter.setCache(mockParams); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('INSERT OR REPLACE INTO cache')); + expect(runMock).toHaveBeenCalledWith(mockParams.key, mockParams.agentId, mockParams.value); + expect(result).toBe(true); + }); + + it('should get cache value', async () => { + mockDb.prepare.mockReturnValueOnce({ + get: vi.fn().mockReturnValueOnce({ value: mockParams.value }) + }); + + const result = await adapter.getCache({ + key: mockParams.key, + agentId: mockParams.agentId + }); + + expect(mockDb.prepare).toHaveBeenCalledWith(expect.stringContaining('SELECT value FROM cache')); + expect(result).toBe(mockParams.value); + }); + + it('should delete cache value', async () => { + const runMock = vi.fn(); + mockDb.prepare.mockReturnValueOnce({ + run: runMock + }); + + const result = await adapter.deleteCache({ + key: mockParams.key, + agentId: mockParams.agentId + }); + + expect(mockDb.prepare).toHaveBeenCalledWith('DELETE FROM cache WHERE key = ? AND agentId = ?'); + expect(runMock).toHaveBeenCalledWith(mockParams.key, mockParams.agentId); + expect(result).toBe(true); + }); + }); });