diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index d4bdf8332cf..fdc4dc52599 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -6,15 +6,15 @@ import { type Character, type ClientInstance, DbCacheAdapter, - logger, - FsCacheAdapter, type IAgentRuntime, type IDatabaseAdapter, type IDatabaseCacheAdapter, + logger, + ModelType, parseBooleanFromText, settings, stringToUuid, - validateCharacterConfig, + validateCharacterConfig } from "@elizaos/core"; import { bootstrapPlugin } from "@elizaos/plugin-bootstrap"; import fs from "node:fs"; @@ -349,10 +349,17 @@ export async function initializeClients( clients.push(startedClient); } } + if (plugin.handlers) { + for (const [modelType, handler] of Object.entries(plugin.handlers)) { + runtime.registerHandler(modelType as ModelType, handler); + } + } } } - return clients; + runtime.clients = clients; + + } export async function createAgent( @@ -374,18 +381,6 @@ export async function createAgent( }); } -function initializeFsCache(baseDir: string, character: Character) { - if (!character?.id) { - throw new Error( - "initializeFsCache requires id to be set in character definition" - ); - } - const cacheDir = path.resolve(baseDir, character.id, "cache"); - - const cache = new CacheManager(new FsCacheAdapter(cacheDir)); - return cache; -} - function initializeDbCache(character: Character, db: IDatabaseCacheAdapter) { if (!character?.id) { throw new Error( @@ -412,15 +407,6 @@ function initializeCache( "Database adapter is not provided for CacheStore.Database." ); - case CacheStore.FILESYSTEM: - logger.info("Using File System Cache..."); - if (!baseDir) { - throw new Error( - "baseDir must be provided for CacheStore.FILESYSTEM." - ); - } - return initializeFsCache(baseDir, character); - default: throw new Error( `Invalid cache store: ${cacheStore} or required configuration missing.` @@ -428,7 +414,6 @@ function initializeCache( } } - async function findDatabaseAdapter(runtime: AgentRuntime) { const { adapters } = runtime; let adapter: Adapter | undefined; @@ -449,8 +434,6 @@ async function findDatabaseAdapter(runtime: AgentRuntime) { return adapterInterface; } - - async function startAgent( character: Character, characterServer: CharacterServer @@ -482,7 +465,7 @@ async function startAgent( await runtime.initialize(); // start assigned clients - runtime.clients = await initializeClients(character, runtime); + await initializeClients(character, runtime); // add to container characterServer.registerAgent(runtime); diff --git a/packages/agent/src/server.ts b/packages/agent/src/server.ts index f1ea0bc001e..cef6c4e4980 100644 --- a/packages/agent/src/server.ts +++ b/packages/agent/src/server.ts @@ -5,13 +5,13 @@ import { generateImage, generateMessageResponse, generateObject, - getEmbeddingZeroVector, messageCompletionFooter, - ModelClass, + ModelType, stringToUuid, type Content, type Media, - type Memory + type Memory, + type IAgentRuntime } from "@elizaos/core"; import bodyParser from "body-parser"; import cors from "cors"; @@ -166,7 +166,7 @@ export class CharacterServer { return; } - const transcription = await runtime.getModelProviderManager().call(ModelClass.AUDIO_TRANSCRIPTION, { + const transcription = await runtime.getModelProviderManager().call(ModelType.AUDIO_TRANSCRIPTION, { file: fs.createReadStream(audioFile.path), model: "whisper-1", }); @@ -276,7 +276,7 @@ export class CharacterServer { const response = await generateMessageResponse({ runtime: runtime, context, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); if (!response) { @@ -286,13 +286,17 @@ export class CharacterServer { return; } + const zeroVector = runtime.getModelProviderManager().call(ModelType.EMBEDDING, { + text: null, + }); + // save response to memory const responseMessage: Memory = { id: stringToUuid(`${messageId}-${runtime.agentId}`), ...userMessage, userId: runtime.agentId, content: response, - embedding: getEmbeddingZeroVector(), + embedding: zeroVector, createdAt: Date.now(), }; @@ -488,7 +492,7 @@ export class CharacterServer { const response = await generateObject({ runtime, context, - modelClass: ModelClass.TEXT_SMALL, + modelType: ModelType.TEXT_SMALL, schema: hyperfiOutSchema, }); @@ -791,7 +795,7 @@ export class CharacterServer { const response = await generateMessageResponse({ runtime: runtime, context, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); // save response to memory @@ -824,7 +828,7 @@ export class CharacterServer { // Get the text to convert to speech const textToSpeak = response.text; - const speechResponse = await runtime.getModelProviderManager().call(ModelClass.AUDIO_TRANSCRIPTION, { + const speechResponse = await runtime.getModelProviderManager().call(ModelType.AUDIO_TRANSCRIPTION, { text: textToSpeak, runtime, }); diff --git a/packages/client/src/components/overview.tsx b/packages/client/src/components/overview.tsx index 68b77baec22..7d9ff031126 100644 --- a/packages/client/src/components/overview.tsx +++ b/packages/client/src/components/overview.tsx @@ -14,7 +14,6 @@ export default function Overview({ character }: { character: Character }) { - { - class MockFlagEmbedding { - - - static async init() { - return new MockFlagEmbedding(); - } - - async queryEmbed(text: string | string[]) { - return [new Float32Array(384).fill(0.1)]; - } - } - - return { - FlagEmbedding: MockFlagEmbedding, - EmbeddingModel: { - BGESmallENV15: "BGE-small-en-v1.5", - }, - }; -}); - -// Mock fetch for remote embedding calls -global.fetch = vi.fn(); - -describe("Embedding Module", () => { - let mockRuntime: IAgentRuntime; - - beforeEach(() => { - // Reset all mocks - vi.clearAllMocks(); - - // Prepare a mock runtime - mockRuntime = { - agentId: "00000000-0000-0000-0000-000000000000" as `${string}-${string}-${string}-${string}-${string}`, - providers: [], - actions: [], - evaluators: [], - plugins: [], - character: { - name: "Test Character", - username: "test", - bio: ["Test bio"], - lore: ["Test lore"], - messageExamples: [], - }, - getModelManager: () => (), - messageManager: { - getCachedEmbeddings: vi.fn().mockResolvedValue([]) - } - } as unknown as IAgentRuntime; - - // Reset fetch mock with proper Response object - const mockResponse = { - ok: true, - json: async () => ({ - data: [{ embedding: new Array(384).fill(0.1) }], - }), - headers: new Headers(), - redirected: false, - status: 200, - statusText: "OK", - type: "basic", - url: "https://api.openai.com/v1/embeddings", - body: null, - bodyUsed: false, - clone: () => ({} as Response), - arrayBuffer: async () => new ArrayBuffer(0), - blob: async () => new Blob(), - formData: async () => new FormData(), - text: async () => "" - } as Response; - - vi.mocked(global.fetch).mockReset(); - vi.mocked(global.fetch).mockResolvedValue(mockResponse); - }); - - describe("getEmbeddingConfig", () => { - test("should return OpenAI config by default", () => { - const config = getEmbeddingConfig(); - expect(config.provider).toBe(EmbeddingProvider.OpenAI); - expect(config.model).toBe("text-embedding-3-small"); - expect(config.dimensions).toBe(1536); - }); - - test("should use runtime provider when available", () => { - const mockModelProvider = { - provider: EmbeddingProvider.OpenAI, - models: { - [ModelClass.TEXT_EMBEDDING]: { - name: "text-embedding-3-small", - dimensions: 1536 - } - } - }; - - const runtime = { - getModelManager: () => mockModelProvider - } as unknown as IAgentRuntime; - - const config = getEmbeddingConfig(runtime); - expect(config.provider).toBe(EmbeddingProvider.OpenAI); - expect(config.model).toBe("text-embedding-3-small"); - expect(config.dimensions).toBe(1536); - }); - }); - - describe("getEmbeddingType", () => { - test("should return 'local' by default", () => { - const type = getEmbeddingType(mockRuntime); - expect(type).toBe("local"); - }); - }); - - describe("getEmbeddingZeroVector", () => { - test("should return 384-length zero vector by default (BGE)", () => { - const vector = getEmbeddingZeroVector(); - expect(vector).toHaveLength(384); - expect(vector.every((val) => val === 0)).toBe(true); - }); - - test("should return 1536-length zero vector for OpenAI if enabled", () => { - const vector = getEmbeddingZeroVector(); - expect(vector).toHaveLength(1536); - expect(vector.every((val) => val === 0)).toBe(true); - }); - }); - - describe("embed function", () => { - test("should return an empty array for empty input text", async () => { - const result = await embed(mockRuntime, ""); - expect(result).toEqual([]); - }); - - test("should return cached embedding if it already exists", async () => { - const cachedEmbedding = new Array(384).fill(0.5); - mockRuntime.messageManager.getCachedEmbeddings = vi - .fn() - .mockResolvedValue([{ embedding: cachedEmbedding }]); - - const result = await embed(mockRuntime, "test input"); - expect(result).toBe(cachedEmbedding); - }); - - test("should handle local embedding successfully", async () => { - const result = await embed(mockRuntime, "test input"); - expect(result).toHaveLength(384); - expect(result.every((v) => typeof v === "number")).toBe(true); - }); - - test("should handle remote embedding successfully", async () => { - const result = await embed(mockRuntime, "test input"); - expect(result).toHaveLength(384); - expect(vi.mocked(global.fetch)).toHaveBeenCalled(); - }); - - test("should throw on remote embedding if fetch fails", async () => { - // Mock fetch to reject - vi.mocked(global.fetch).mockRejectedValueOnce(new Error("API Error")); - - await expect(embed(mockRuntime, "test input")).rejects.toThrow("API Error"); - }); - - test("should handle concurrent embedding requests", async () => { - const promises = Array(5) - .fill(null) - .map(() => embed(mockRuntime, "concurrent test")); - await expect(Promise.all(promises)).resolves.toBeDefined(); - }); - }); - - // Add tests for new embedding configurations - describe("embedding configuration", () => { - test("should handle embedding provider configuration", async () => { - const mockModelProvider = { - generateText: vi.fn(), - generateObject: vi.fn(), - generateImage: vi.fn(), - generateEmbedding: vi.fn(), - provider: EmbeddingProvider.OpenAI, - models: { - [ModelClass.TEXT_EMBEDDING]: { - name: "text-embedding-3-small", - dimensions: 1536 - } - }, - getModelManager: () => mockModelProvider - }; - - const runtime = { - agentId: "test-agent", - databaseAdapter: {} as IDatabaseAdapter, - getModelManager: () => mockModelProvider, - } as unknown as IAgentRuntime; - - const config = getEmbeddingConfig(runtime); - expect(config.provider).toBe(EmbeddingProvider.OpenAI); - expect(config.model).toBe("text-embedding-3-small"); - expect(config.dimensions).toBe(1536); - }); - - test("should return default config when no runtime provided", () => { - const config = getEmbeddingConfig(); - expect(config.provider).toBe(EmbeddingProvider.OpenAI); - expect(config.model).toBe("text-embedding-3-small"); - expect(config.dimensions).toBe(1536); - }); - }); - - describe("embedding type detection", () => { - test("should determine embedding type based on runtime configuration", () => { - const mockRuntimeLocal = { - ...mockRuntime, - getSetting: (key: string) => "false" - } as IAgentRuntime; - - expect(getEmbeddingType(mockRuntimeLocal)).toBe("local"); - }); - }); -}); diff --git a/packages/core/__tests__/knowledge.test.ts b/packages/core/__tests__/knowledge.test.ts index cc03f98ea06..78ed60188b1 100644 --- a/packages/core/__tests__/knowledge.test.ts +++ b/packages/core/__tests__/knowledge.test.ts @@ -1,15 +1,7 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import knowledge from "../src/knowledge"; import type { AgentRuntime } from "../src/runtime"; -import { KnowledgeItem, type Memory } from "../src/types"; - -// Mock dependencies -vi.mock("../embedding", () => ({ - embed: vi.fn().mockResolvedValue(new Float32Array(1536).fill(0)), - getEmbeddingZeroVector: vi - .fn() - .mockReturnValue(new Float32Array(1536).fill(0)), -})); +import { type Memory } from "../src/types"; vi.mock("../generation", () => ({ splitChunks: vi.fn().mockImplementation(async (text) => [text]), diff --git a/packages/core/__tests__/runtime.test.ts b/packages/core/__tests__/runtime.test.ts index f82d21a8dab..dc498298733 100644 --- a/packages/core/__tests__/runtime.test.ts +++ b/packages/core/__tests__/runtime.test.ts @@ -5,19 +5,11 @@ import { type IDatabaseAdapter, type IMemoryManager, type Memory, - ModelClass, - ServiceType, + ModelType, type UUID } from "../src/types"; import { mockCharacter } from "./mockCharacter"; -// Mock the embedding module -vi.mock("../src/embedding", () => ({ - embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]), - getRemoteEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])), - getLocalEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])) -})); - // Mock dependencies with minimal implementations const mockDatabaseAdapter: IDatabaseAdapter = { db: {}, @@ -86,9 +78,6 @@ describe("AgentRuntime", () => { beforeEach(() => { vi.clearAllMocks(); - const ModelManager = { - getProvider: () => mockModelProvider, - }; runtime = new AgentRuntime({ character: { @@ -142,23 +131,9 @@ describe("AgentRuntime", () => { }); }); - describe("service management", () => { - it("should allow registering and retrieving services", async () => { - const mockService = { - serviceType: ServiceType.TEXT_GENERATION, - type: ServiceType.TEXT_GENERATION, - initialize: vi.fn().mockResolvedValue(undefined), - }; - - await runtime.registerService(mockService); - const retrievedService = runtime.getService(ServiceType.TEXT_GENERATION); - expect(retrievedService).toBe(mockService); - }); - }); - describe("model provider management", () => { it("should provide access to the configured model provider", () => { - const provider = runtime.getModelManager(); + const provider = runtime; expect(provider).toBeDefined(); }); }); @@ -239,17 +214,17 @@ describe("Model Provider Configuration", () => { cacheManager: mockCacheManager, }); - const provider = runtime.getModelManager(); + const provider = runtime; expect(provider.models.default).toBeDefined(); expect(provider.models.default.name).toBeDefined(); }); test("should handle missing optional model configurations", () => { - const provider = runtime.getModelManager(); + const provider = runtime; // These might be undefined but shouldn't throw errors - expect(() => provider.models[ModelClass.IMAGE]).not.toThrow(); - expect(() => provider.models[ModelClass.IMAGE_VISION]).not.toThrow(); + expect(() => provider.models[ModelType.IMAGE_GENERATION]).not.toThrow(); + expect(() => provider.models[ModelType.IMAGE_DESCRIPTION]).not.toThrow(); }); test("should validate model provider name format", () => { @@ -260,7 +235,6 @@ describe("Model Provider Configuration", () => { expect(() => new AgentRuntime({ character: { ...mockCharacter, - modelProvider: invalidProvider, bio: ["Test bio"], // Ensure bio is an array lore: ["Test lore"], // Ensure lore is an array messageExamples: [], // Required by Character type @@ -285,7 +259,6 @@ describe("Model Provider Configuration", () => { expect(() => new AgentRuntime({ character: { ...mockCharacter, - modelProvider: validProvider, bio: ["Test bio"], // Ensure bio is an array lore: ["Test lore"], // Ensure lore is an array messageExamples: [], // Required by Character type @@ -298,7 +271,6 @@ describe("Model Provider Configuration", () => { post: [] } }, - modelProvider: validProvider, databaseAdapter: mockDatabaseAdapter, cacheManager: mockCacheManager, })).not.toThrow(); @@ -307,22 +279,6 @@ describe("Model Provider Configuration", () => { }); }); -describe("ModelManager", () => { - test("should get correct model provider settings", async () => { - const runtime = new AgentRuntime({ - databaseAdapter: mockDatabaseAdapter, - cacheManager: { - get: vi.fn(), - set: vi.fn(), - delete: vi.fn(), - }, - }); - - const provider = runtime.getModelManager(); - expect(provider).toBeDefined(); - }); -}); - describe("MemoryManagerService", () => { test("should provide access to different memory managers", async () => { const runtime = new AgentRuntime({ @@ -361,23 +317,4 @@ describe("MemoryManagerService", () => { runtime.registerMemoryManager(customManager); expect(runtime.getMemoryManager("custom")).toBe(customManager); }); -}); - -describe("ServiceManager", () => { - test("should handle service registration and retrieval", async () => { - const runtime = new AgentRuntime({ - databaseAdapter: mockDatabaseAdapter, - cacheManager: mockCacheManager - }); - - const mockService = { - serviceType: ServiceType.TEXT_GENERATION, - type: ServiceType.TEXT_GENERATION, - initialize: vi.fn().mockResolvedValue(undefined) - }; - - await runtime.registerService(mockService); - const retrievedService = runtime.getService(ServiceType.TEXT_GENERATION); - expect(retrievedService).toBe(mockService); - }); -}); +}); \ No newline at end of file diff --git a/packages/core/src/cache.ts b/packages/core/src/cache.ts index 7cf2a38f9ce..ced8218cca9 100644 --- a/packages/core/src/cache.ts +++ b/packages/core/src/cache.ts @@ -33,39 +33,6 @@ export class MemoryCacheAdapter implements ICacheAdapter { } } -export class FsCacheAdapter implements ICacheAdapter { - constructor(private dataDir: string) {} - - async get(key: string): Promise { - try { - return await fs.readFile(path.join(this.dataDir, key), "utf8"); - } catch { - // console.error(error); - return undefined; - } - } - - async set(key: string, value: string): Promise { - try { - const filePath = path.join(this.dataDir, key); - // Ensure the directory exists - await fs.mkdir(path.dirname(filePath), { recursive: true }); - await fs.writeFile(filePath, value, "utf8"); - } catch (error) { - console.error(error); - } - } - - async delete(key: string): Promise { - try { - const filePath = path.join(this.dataDir, key); - await fs.unlink(filePath); - } catch { - // console.error(error); - } - } -} - export class DbCacheAdapter implements ICacheAdapter { constructor( private db: IDatabaseCacheAdapter, diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts deleted file mode 100644 index a444f252855..00000000000 --- a/packages/core/src/embedding.ts +++ /dev/null @@ -1,69 +0,0 @@ -// TODO: Maybe create these functions to read from character settings or env -// import { getEmbeddingModelSettings, getEndpoint } from "./models.ts"; -import logger from "./logger.ts"; -import { type IAgentRuntime } from "./types.ts"; - -export type EmbeddingConfig = { - readonly dimensions: number; - readonly model: string; - readonly provider: string; -}; - -/** - * Gets embeddings from a remote API endpoint. Falls back to local BGE/384 - * - * @param {string} input - The text to generate embeddings for - * @param {EmbeddingOptions} options - Configuration options including: - * - model: The model name to use - * - endpoint: Base API endpoint URL - * - apiKey: Optional API key for authentication - * - isOllama: Whether this is an Ollama endpoint - * - dimensions: Desired embedding dimensions - * @param {IAgentRuntime} runtime - The agent runtime context - * @returns {Promise} Array of embedding values - * @throws {Error} If the API request fails - */ - -export async function embed(runtime: IAgentRuntime, input: string) { - logger.debug("Embedding request:", { - input: `${input?.slice(0, 50)}...`, - inputType: typeof input, - inputLength: input?.length, - isString: typeof input === "string", - isEmpty: !input, - }); - - // Validate input - if (!input || typeof input !== "string" || input.trim().length === 0) { - logger.warn("Invalid embedding input:", { - input, - type: typeof input, - length: input?.length, - }); - return []; // Return empty embedding array - } - - // Check cache first - const cachedEmbedding = await retrieveCachedEmbedding(runtime, input); - if (cachedEmbedding) return cachedEmbedding; - - // Use remote embedding - return await runtime.getModelManager().generateEmbedding(input); - - async function retrieveCachedEmbedding( - runtime: IAgentRuntime, - input: string - ) { - if (!input) { - logger.log("No input to retrieve cached embedding for"); - return null; - } - - const similaritySearchResult = - await runtime.messageManager.getCachedEmbeddings(input); - if (similaritySearchResult.length > 0) { - return similaritySearchResult[0].embedding; - } - return null; - } -} diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index 43fc70d3b27..c87f9f1abe5 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -5,13 +5,13 @@ import { parseJSONObjectFromText } from "./parsing.ts"; import { type Content, type IAgentRuntime, - ModelClass + ModelType } from "./types.ts"; interface GenerateObjectOptions { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; output?: "object" | "array" | "enum" | "no-schema" | undefined; schema?: ZodSchema; schemaName?: string; @@ -69,21 +69,20 @@ async function withRetry( export async function generateText({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; stopSequences?: string[]; customSystemPrompt?: string; }): Promise { logFunctionCall("generateText", runtime); - const { text } = await runtime.getModelManager().generateText({ + const { text } = await runtime.call(modelType, { context, - modelClass, - stop: stopSequences, + stopSequences, }); return text; @@ -92,12 +91,12 @@ export async function generateText({ export async function generateTextArray({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; stopSequences?: string[]; }): Promise { logFunctionCall("generateTextArray", runtime); @@ -106,7 +105,7 @@ export async function generateTextArray({ const result = await generateObject({ runtime, context, - modelClass, + modelType, schema: z.array(z.string()), stopSequences, }); @@ -120,14 +119,14 @@ export async function generateTextArray({ async function generateEnum({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, enumValues, functionName, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; enumValues: Array; functionName: string; stopSequences?: string[]; @@ -142,7 +141,7 @@ async function generateEnum({ const result = await generateObject({ runtime, context, - modelClass, + modelType, output: "enum", enum: enumValues, mode: "json", @@ -159,12 +158,12 @@ async function generateEnum({ export async function generateShouldRespond({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; stopSequences?: string[]; }): Promise<"RESPOND" | "IGNORE" | "STOP" | null> { const RESPONSE_VALUES = ["RESPOND", "IGNORE", "STOP"] as string[]; @@ -172,7 +171,7 @@ export async function generateShouldRespond({ const result = await generateEnum({ runtime, context, - modelClass, + modelType, enumValues: RESPONSE_VALUES, functionName: "generateShouldRespond", stopSequences, @@ -184,12 +183,12 @@ export async function generateShouldRespond({ export async function generateTrueOrFalse({ runtime, context = "", - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; stopSequences?: string[]; }): Promise { logFunctionCall("generateTrueOrFalse", runtime); @@ -199,7 +198,7 @@ export async function generateTrueOrFalse({ const result = await generateEnum({ runtime, context, - modelClass, + modelType, enumValues: BOOL_VALUES, functionName: "generateTrueOrFalse", stopSequences, @@ -212,7 +211,7 @@ export async function generateTrueOrFalse({ export const generateObject = async ({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, output = "object", schema, schemaName, @@ -228,27 +227,26 @@ export const generateObject = async ({ throw new Error(errorMessage); } - const { object } = await runtime.getModelManager().generateObject({ + const { object } = await runtime.call(modelType, { context, - modelClass, stop: stopSequences, }); - logger.debug(`Received Object response from ${modelClass} model.`); + logger.debug(`Received Object response from ${modelType} model.`); return schema ? schema.parse(object) : object; }; export async function generateObjectArray({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, schema, schemaName, schemaDescription, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; schema?: ZodSchema; schemaName?: string; schemaDescription?: string; @@ -261,7 +259,7 @@ export async function generateObjectArray({ const result = await generateObject({ runtime, context, - modelClass, + modelType, output: "array", schema, schemaName, @@ -274,12 +272,12 @@ export async function generateObjectArray({ export async function generateMessageResponse({ runtime, context, - modelClass = ModelClass.TEXT_SMALL, + modelType = ModelType.TEXT_SMALL, stopSequences, }: { runtime: IAgentRuntime; context: string; - modelClass: ModelClass; + modelType: ModelType; stopSequences?: string[]; }): Promise { logFunctionCall("generateMessageResponse", runtime); @@ -287,9 +285,8 @@ export async function generateMessageResponse({ logger.debug("Context:", context); return await withRetry(async () => { - const { text } = await runtime.getModelManager().generateText({ + const { text } = await runtime.call(modelType, { context, - modelClass, stop: stopSequences, }); @@ -334,7 +331,7 @@ export const generateImage = async ( return await withRetry( async () => { - const result = await runtime.getModelManager().generateImage(data); + const result = await runtime.call(ModelType.IMAGE, data); return { success: true, data: result.images, @@ -357,13 +354,8 @@ export const generateCaption = async ( }> => { logFunctionCall("generateCaption", runtime); const { imageUrl } = data; - const imageDescriptionService = runtime.getModelManager().describeImage(imageUrl); - - if (!imageDescriptionService) { - throw new Error("Image description service not found"); - } + const resp = await runtime.call(ModelType.IMAGE_DESCRIPTION, imageUrl); - const resp = await imageDescriptionService.describeImage(imageUrl); return { title: resp.title.trim(), description: resp.description.trim(), diff --git a/packages/core/src/helper.ts b/packages/core/src/helper.ts index 25c9479a741..7176680b81a 100644 --- a/packages/core/src/helper.ts +++ b/packages/core/src/helper.ts @@ -1,18 +1,16 @@ -import { AutoTokenizer } from "@huggingface/transformers"; import { encodingForModel, type TiktokenModel } from "js-tiktoken"; -import logger from "./logger.ts"; -import { TokenizerType, type IAgentRuntime, type ModelSettings } from "./types.ts"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import logger from "./logger.ts"; +import { type IAgentRuntime, type ModelSettings } from "./types.ts"; export function logFunctionCall(functionName: string, runtime?: IAgentRuntime) { logger.info(`Function call: ${functionName}`, { functionName, - // runtime: JSON.stringify(runtime?.getModelManager()) + // runtime: JSON.stringify(runtime?) }); } - export async function trimTokens( context: string, maxTokens: number, @@ -30,52 +28,16 @@ export async function trimTokens( return truncateTiktoken("gpt-4o", context, maxTokens); } - // Choose the truncation method based on tokenizer type - if (tokenizerType === TokenizerType.Auto) { - return truncateAuto(tokenizerModel, context, maxTokens); - } - - if (tokenizerType === TokenizerType.TikToken) { - return truncateTiktoken( - tokenizerModel as TiktokenModel, - context, - maxTokens - ); - } + return truncateTiktoken( + tokenizerModel as TiktokenModel, + context, + maxTokens + ); logger.warn(`Unsupported tokenizer type: ${tokenizerType}`); return truncateTiktoken("gpt-4o", context, maxTokens); } - - - -async function truncateAuto( - modelPath: string, - context: string, - maxTokens: number -) { - try { - const tokenizer = await AutoTokenizer.from_pretrained(modelPath); - const tokens = tokenizer.encode(context); - - // If already within limits, return unchanged - if (tokens.length <= maxTokens) { - return context; - } - - // Keep the most recent tokens by slicing from the end - const truncatedTokens = tokens.slice(-maxTokens); - - // Decode back to text - js-tiktoken decode() returns a string directly - return tokenizer.decode(truncatedTokens); - } catch (error) { - logger.error("Error in trimTokens:", error); - // Return truncated string if tokenization fails - return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token - } -} - async function truncateTiktoken( model: TiktokenModel, context: string, @@ -104,7 +66,6 @@ async function truncateTiktoken( } } - export async function splitChunks( content: string, chunkSize = 512, @@ -126,7 +87,6 @@ export async function splitChunks( return chunks; } - export function getModelSettings(modelSettings: Record) { if (!modelSettings) { throw new Error("MODEL_SETTINGS is not defined"); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 6e680f26e48..145983b7a9d 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -4,7 +4,6 @@ export * from "./actions.ts"; export * from "./cache.ts"; export * from "./context.ts"; export * from "./database.ts"; -export * from "./embedding.ts"; export * from "./environment.ts"; export * from "./evaluators.ts"; export * from "./generation.ts"; @@ -21,5 +20,4 @@ export * from "./relationships.ts"; export * from "./runtime.ts"; export * from "./settings.ts"; export * from "./types.ts"; -export * from "./utils.ts"; export * from "./uuid.ts"; diff --git a/packages/core/src/knowledge.ts b/packages/core/src/knowledge.ts index 7dd9857c685..c6dda472a4c 100644 --- a/packages/core/src/knowledge.ts +++ b/packages/core/src/knowledge.ts @@ -1,9 +1,8 @@ -import type { AgentRuntime } from "./runtime.ts"; -import { embed, getEmbeddingZeroVector } from "./embedding.ts"; -import type { KnowledgeItem, UUID, Memory } from "./types.ts"; -import { stringToUuid } from "./uuid.ts"; import { splitChunks } from "./helper.ts"; import logger from "./logger.ts"; +import type { AgentRuntime } from "./runtime.ts"; +import { type KnowledgeItem, type Memory, ModelType, type UUID } from "./types.ts"; +import { stringToUuid } from "./uuid.ts"; async function get( runtime: AgentRuntime, @@ -32,7 +31,9 @@ async function get( return []; } - const embedding = await embed(runtime, processed); + const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, { + text: processed, + }); const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding( embedding, { @@ -70,6 +71,9 @@ async function set( chunkSize = 512, bleed = 20 ) { + const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, { + text: null, + }); await runtime.documentsManager.createMemory({ id: item.id, agentId: runtime.agentId, @@ -77,14 +81,16 @@ async function set( userId: runtime.agentId, createdAt: Date.now(), content: item.content, - embedding: getEmbeddingZeroVector(), + embedding: embedding, }); const preprocessed = preprocess(item.content.text); const fragments = await splitChunks(preprocessed, chunkSize, bleed); for (const fragment of fragments) { - const embedding = await embed(runtime, fragment); + const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, { + text: fragment, + }); await runtime.knowledgeManager.createMemory({ // We namespace the knowledge base uuid to avoid id // collision with the document above. diff --git a/packages/core/src/memory.ts b/packages/core/src/memory.ts index c51d54e691b..bd564662182 100644 --- a/packages/core/src/memory.ts +++ b/packages/core/src/memory.ts @@ -1,10 +1,10 @@ -import { embed, getEmbeddingZeroVector } from "./embedding.ts"; import logger from "./logger.ts"; -import type { - IAgentRuntime, - IMemoryManager, - Memory, - UUID, +import { + ModelType, + type IAgentRuntime, + type IMemoryManager, + type Memory, + type UUID, } from "./types.ts"; const defaultMatchThreshold = 0.1; @@ -66,11 +66,15 @@ export class MemoryManager implements IMemoryManager { try { // Generate embedding from text content - memory.embedding = await embed(this.runtime, memoryText); + memory.embedding = await this.runtime.call(ModelType.TEXT_EMBEDDING, { + text: memoryText, + }); } catch (error) { logger.error("Failed to generate embedding:", error); // Fallback to zero vector if embedding fails - memory.embedding = getEmbeddingZeroVector().slice(); + memory.embedding = await this.runtime.call(ModelType.TEXT_EMBEDDING, { + text: null, + }); } return memory; @@ -140,6 +144,7 @@ export class MemoryManager implements IMemoryManager { match_threshold?: number; count?: number; roomId: UUID; + agentId: UUID; unique?: boolean; } ): Promise { diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts index f3907106435..3c4b21a5692 100644 --- a/packages/core/src/runtime.ts +++ b/packages/core/src/runtime.ts @@ -1,6 +1,3 @@ -import { glob } from "glob"; -import { existsSync } from "node:fs"; -import { readFile } from "node:fs/promises"; import { join } from "node:path"; import { names, uniqueNamesGenerator } from "unique-names-generator"; import { v4 as uuidv4 } from "uuid"; @@ -33,24 +30,18 @@ import { type Character, type ClientInstance, type DirectoryItem, - type EmbeddingModelSettings, type Evaluator, - GenerateTextParams, type Goal, type HandlerCallback, type IAgentRuntime, type ICacheManager, type IDatabaseAdapter, - type ImageModelSettings, type IMemoryManager, - IModelManager, type KnowledgeItem, type Memory, - ModelClass, + ModelType, type Plugin, type Provider, - type Service, - type ServiceType, type State, type UUID } from "./types.ts"; @@ -125,131 +116,6 @@ class KnowledgeManager { } } -/** - * Manages model provider settings and configuration - */ -class ModelManager implements IModelManager { - private runtime: AgentRuntime; - - constructor(runtime: AgentRuntime) { - this.runtime = runtime; - } - - modelHandlers = new Map Promise)[]>(); - - registerModelHandler(modelClass: ModelClass, handler: (params: any) => Promise) { - if (!this.modelHandlers.has(modelClass)) { - this.modelHandlers.set(modelClass, []); - } - this.modelHandlers.get(modelClass)?.push(handler); - } - - getModelHandler(modelClass: ModelClass): ((params: any) => Promise) | undefined { - const handlers = this.modelHandlers.get(modelClass); - if (!handlers?.length) { - return undefined; - } - return handlers[0]; - } - - getApiKey() { - return this.runtime.getSetting("API_KEY"); - } - - // TODO: This could be more of a pure registration handler - generateText (params: GenerateTextParams) { - const handler = this.getModelHandler(params.modelClass); - if (!handler) { - throw new Error(`No handler found for ${params.modelClass}`); - } - return handler(params); - } - - generateEmbedding (text: string) { - const handler = this.getModelHandler(ModelClass.TEXT_EMBEDDING); - if (!handler) { - throw new Error(`No handler found for ${ModelClass.TEXT_EMBEDDING}`); - } - return handler(text); - } - - generateImage (params: ImageModelSettings) { - const handler = this.getModelHandler(ModelClass.IMAGE); - if (!handler) { - throw new Error(`No handler found for ${ModelClass.IMAGE}`); - } - return handler(params); - } - - generateAudio (params: any) { - const handler = this.getModelHandler(ModelClass.AUDIO); - if (!handler) { - throw new Error(`No handler found for ${ModelClass.AUDIO}`); - } - return handler(params); - } -} - -/** - * Manages services and their lifecycle - */ -class ServiceManager { - private runtime: AgentRuntime; - private services: Map; - - constructor(runtime: AgentRuntime) { - this.runtime = runtime; - this.services = new Map(); - } - - async registerService(service: Service): Promise { - const serviceType = service.serviceType; - logger.log(`${this.runtime.character.name}(${this.runtime.agentId}) - Registering service:`, serviceType); - - if (this.services.has(serviceType)) { - logger.warn( - `${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} is already registered. Skipping registration.` - ); - return; - } - - // Add the service to the services map - this.services.set(serviceType, service); - logger.success(`${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} registered successfully`); - } - - getService(service: ServiceType): T | null { - const serviceInstance = this.services.get(service); - if (!serviceInstance) { - logger.error(`Service ${service} not found`); - return null; - } - return serviceInstance as T; - } - - async initializeServices(): Promise { - for (const [serviceType, service] of this.services.entries()) { - try { - await service.initialize(this.runtime); - this.services.set(serviceType, service); - logger.success( - `${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} initialized successfully` - ); - } catch (error) { - logger.error( - `${this.runtime.character.name}(${this.runtime.agentId}) - Failed to initialize service ${serviceType}:`, - error - ); - throw error; - } - } - } - - getAllServices(): Map { - return this.services; - } -} - /** * Manages memory-related operations and memory managers */ @@ -372,15 +238,14 @@ export class AgentRuntime implements IAgentRuntime { readonly fetch = fetch; public cacheManager!: ICacheManager; public clients: ClientInstance[] = []; - readonly services: Map; public adapters: Adapter[]; private readonly knowledgeRoot: string; - private readonly modelManager: IModelManager; - private readonly serviceManager: ServiceManager; private readonly memoryManagerService: MemoryManagerService; + handlers = new Map Promise)[]>(); + constructor(opts: { conversationLength?: number; agentId?: UUID; @@ -390,7 +255,6 @@ export class AgentRuntime implements IAgentRuntime { evaluators?: Evaluator[]; plugins?: Plugin[]; providers?: Provider[]; - services?: Service[]; managers?: IMemoryManager[]; databaseAdapter?: IDatabaseAdapter; fetch?: typeof fetch; @@ -435,11 +299,6 @@ export class AgentRuntime implements IAgentRuntime { this.cacheManager = opts.cacheManager; } - this.services = new Map(); - - // Initialize managers - moved ModelManager initialization earlier - this.modelManager = new ModelManager(this); - this.serviceManager = new ServiceManager(this); this.memoryManagerService = new MemoryManagerService(this, this.knowledgeRoot); // Register additional memory managers from options @@ -449,21 +308,6 @@ export class AgentRuntime implements IAgentRuntime { } } - // Register services from options and plugins - if (opts.services) { - for (const service of opts.services) { - this.registerService(service); - } - } - - for (const plugin of this.plugins) { - if (plugin.services) { - for (const service of plugin.services) { - this.registerService(service); - } - } - } - this.plugins = [ ...(opts.character?.plugins ?? []), ...(opts.plugins ?? []), @@ -478,10 +322,6 @@ export class AgentRuntime implements IAgentRuntime { this.registerEvaluator(evaluator); } - for (const service of (plugin.services ?? [])) { - this.registerService(service); - } - for (const provider of (plugin.providers ?? [])) { this.registerContextProvider(provider); } @@ -503,10 +343,6 @@ export class AgentRuntime implements IAgentRuntime { this.adapters = opts.adapters ?? []; } - getModelManager(): IModelManager { - return this.modelManager; - } - async initialize() { await this.ensureRoomExists(this.agentId); await this.ensureUserExists( @@ -516,8 +352,6 @@ export class AgentRuntime implements IAgentRuntime { ); await this.ensureParticipantExists(this.agentId, this.agentId); - await this.serviceManager.initializeServices(); - if (this.character?.knowledge && this.character.knowledge.length > 0) { // Non-RAG mode: only process string knowledge const stringKnowledge = this.character.knowledge.filter( @@ -730,7 +564,7 @@ export class AgentRuntime implements IAgentRuntime { const result = await generateText({ runtime: this, context, - modelClass: ModelClass.TEXT_SMALL, + modelType: ModelType.TEXT_SMALL, }); const evaluators = parseJsonArrayFromText( @@ -1366,14 +1200,6 @@ Text: ${attachment.text} return this.memoryManagerService.getMemoryManager(tableName); } - getService(service: ServiceType): T | null { - return this.serviceManager.getService(service); - } - - async registerService(service: Service): Promise { - await this.serviceManager.registerService(service); - } - // Memory manager getters get messageManager(): IMemoryManager { return this.memoryManagerService.getMessageManager(); @@ -1394,4 +1220,27 @@ Text: ${attachment.text} get knowledgeManager(): IMemoryManager { return this.memoryManagerService.getKnowledgeManager(); } + + registerHandler(modelType: ModelType, handler: (params: any) => Promise) { + if (!this.handlers.has(modelType)) { + this.handlers.set(modelType, []); + } + this.handlers.get(modelType)?.push(handler); + } + + getHandler(modelType: ModelType): ((params: any) => Promise) | undefined { + const handlers = this.handlers.get(modelType); + if (!handlers?.length) { + return undefined; + } + return handlers[0]; + } + + call(modelType: ModelType, params: any): Promise { + const handler = this.getHandler(modelType); + if (!handler) { + throw new Error(`No handler found for model type: ${modelType}`); + } + return handler(params); + } } diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 0931969cb77..f004449420a 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -116,12 +116,12 @@ export interface Goal { /** * Model size/type classification */ -export enum ModelClass { +export enum ModelType { TEXT_SMALL = "text_small", TEXT_LARGE = "text_large", TEXT_EMBEDDING = "text_embedding", IMAGE = "image", - VISION = "vision", + IMAGE_DESCRIPTION = "image_description", TRANSCRIPTION = "transcription", TEXT_TO_SPEECH = "text_to_speech", SPEECH_TO_TEXT = "speech_to_text", @@ -561,14 +561,16 @@ export type Plugin = { /** Optional evaluators */ evaluators?: Evaluator[]; - /** Optional services */ - services?: Service[]; - /** Optional clients */ clients?: Client[]; /** Optional adapters */ adapters?: Adapter[]; + + /** Optional handlers */ + handlers?: { + [key: string]: (...args: any[]) => Promise; + }; }; export interface ModelConfiguration { @@ -887,36 +889,6 @@ export interface ICacheManager { delete(key: string): Promise; } -export abstract class Service { - private static instance: Service | null = null; - - static get serviceType(): ServiceType { - throw new Error("Service must implement static serviceType getter"); - } - - public static getInstance(): T { - if (!Service.instance) { - Service.instance = new (this as any)(); - } - return Service.instance as T; - } - - get serviceType(): ServiceType { - return (this.constructor as typeof Service).serviceType; - } - - // Add abstract initialize method that must be implemented by derived classes - abstract initialize(runtime: IAgentRuntime): Promise; -} - -export interface IModelManager { - generateText(params: GenerateTextParams): Promise<{ text: string }>; - generateObject(params: GenerateTextParams): Promise<{ object: any }>; - generateEmbedding(text: string): Promise; - generateImage(params: ImageModelSettings): Promise<{ images: string[] }>; - generateAudio(params: any): Promise<{ audio: string }>; -} - export interface IAgentRuntime { // Properties agentId: UUID; @@ -938,7 +910,6 @@ export interface IAgentRuntime { cacheManager: ICacheManager; - services: Map; clients: ClientInstance[]; initialize(): Promise; @@ -947,14 +918,8 @@ export interface IAgentRuntime { getMemoryManager(name: string): IMemoryManager | null; - getService(service: ServiceType): T | null; - - registerService(service: Service): void; - getSetting(key: string): string | null; - getModelManager(): IModelManager; - // Methods getConversationLength(): number; @@ -1001,156 +966,10 @@ export interface IAgentRuntime { ): Promise; updateRecentMessageState(state: State): Promise; -} - -export interface IImageDescriptionService extends Service { - describeImage( - imageUrl: string - ): Promise<{ title: string; description: string }>; -} - -export interface ITranscriptionService extends Service { - transcribeAttachment(audioBuffer: ArrayBuffer): Promise; - transcribeAttachmentLocally(audioBuffer: ArrayBuffer): Promise; - transcribe(audioBuffer: ArrayBuffer): Promise; - transcribeLocally(audioBuffer: ArrayBuffer): Promise; -} - -export interface IVideoService extends Service { - isVideoUrl(url: string): boolean; - fetchVideoInfo(url: string): Promise; - downloadVideo(videoInfo: Media): Promise; - processVideo(url: string, runtime: IAgentRuntime): Promise; -} - -export interface ITextGenerationService extends Service { - initializeModel(): Promise; - queueMessageCompletion( - context: string, - temperature: number, - stop: string[], - frequency_penalty: number, - presence_penalty: number, - max_tokens: number - ): Promise; - queueTextCompletion( - context: string, - temperature: number, - stop: string[], - frequency_penalty: number, - presence_penalty: number, - max_tokens: number - ): Promise; - getEmbeddingResponse(input: string): Promise; -} - -export interface IBrowserService extends Service { - closeBrowser(): Promise; - getPageContent( - url: string, - runtime: IAgentRuntime - ): Promise<{ title: string; description: string; bodyContent: string }>; -} -export interface ISpeechService extends Service { - getInstance(): ISpeechService; - generate(runtime: IAgentRuntime, text: string): Promise; -} - -export interface IPdfService extends Service { - getInstance(): IPdfService; - convertPdfToText(pdfBuffer: Buffer): Promise; -} - -export interface IAwsS3Service extends Service { - uploadFile( - imagePath: string, - subDirectory: string, - useSignedUrl: boolean, - expiresIn: number - ): Promise<{ - success: boolean; - url?: string; - error?: string; - }>; - generateSignedUrl(fileName: string, expiresIn: number): Promise; -} - -export interface UploadIrysResult { - success: boolean; - url?: string; - error?: string; - data?: any; -} - -export interface DataIrysFetchedFromGQL { - success: boolean; - data: any; - error?: string; -} - -export interface GraphQLTag { - name: string; - values: any[]; -} - -export enum IrysMessageType { - REQUEST = "REQUEST", - DATA_STORAGE = "DATA_STORAGE", - REQUEST_RESPONSE = "REQUEST_RESPONSE", -} - -export enum IrysDataType { - FILE = "FILE", - IMAGE = "IMAGE", - OTHER = "OTHER", -} - -export interface IrysTimestamp { - from: number; - to: number; -} - -export interface IIrysService extends Service { - getDataFromAnAgent( - agentsWalletPublicKeys: string[], - tags: GraphQLTag[], - timestamp: IrysTimestamp - ): Promise; - workerUploadDataOnIrys( - data: any, - dataType: IrysDataType, - messageType: IrysMessageType, - serviceCategory: string[], - protocol: string[], - validationThreshold: number[], - minimumProviders: number[], - testProvider: boolean[], - reputation: number[] - ): Promise; - providerUploadDataOnIrys( - data: any, - dataType: IrysDataType, - serviceCategory: string[], - protocol: string[] - ): Promise; -} - -export interface ITeeLogService extends Service { - getInstance(): ITeeLogService; - log( - agentId: string, - roomId: string, - userId: string, - type: string, - content: string - ): Promise; -} - -export enum ServiceType { - BROWSER = "browser", - PDF = "pdf", - STORAGE = "STORAGE", + call(modelType: ModelType, params: any): Promise; + registerHandler(modelType: ModelType, handler: (params: any) => Promise): void; + getHandler(modelType: ModelType): ((params: any) => Promise) | undefined; } export enum LoggingLevel { @@ -1184,7 +1003,8 @@ export interface ChunkRow { } export type GenerateTextParams = { + runtime: IAgentRuntime; context: string; - modelClass: ModelClass; - stop?: string[]; + modelType: ModelType; + stopSequences?: string[]; }; diff --git a/packages/core/src/utils.ts b/packages/core/src/utils.ts deleted file mode 100644 index 394087f04bf..00000000000 --- a/packages/core/src/utils.ts +++ /dev/null @@ -1,4 +0,0 @@ -export { embed } from "./embedding.ts"; -export { logger } from "./logger.ts"; -export { AgentRuntime } from "./runtime.ts"; - diff --git a/packages/core/src/uuid.ts b/packages/core/src/uuid.ts index 4cc8f09e700..180484ad810 100644 --- a/packages/core/src/uuid.ts +++ b/packages/core/src/uuid.ts @@ -5,55 +5,55 @@ import type { UUID } from "./types.ts"; export const uuidSchema = z.string().uuid() as z.ZodType; export function validateUuid(value: unknown): UUID | null { - const result = uuidSchema.safeParse(value); - return result.success ? result.data : null; + const result = uuidSchema.safeParse(value); + return result.success ? result.data : null; } export function stringToUuid(target: string | number): UUID { - if (typeof target === "number") { - target = (target as number).toString(); + if (typeof target === "number") { + target = (target as number).toString(); + } + + if (typeof target !== "string") { + throw TypeError("Value must be string"); + } + + const _uint8ToHex = (ubyte: number): string => { + const first = ubyte >> 4; + const second = ubyte - (first << 4); + const HEX_DIGITS = "0123456789abcdef".split(""); + return HEX_DIGITS[first] + HEX_DIGITS[second]; + }; + + const _uint8ArrayToHex = (buf: Uint8Array): string => { + let out = ""; + for (let i = 0; i < buf.length; i++) { + out += _uint8ToHex(buf[i]); } - - if (typeof target !== "string") { - throw TypeError("Value must be string"); - } - - const _uint8ToHex = (ubyte: number): string => { - const first = ubyte >> 4; - const second = ubyte - (first << 4); - const HEX_DIGITS = "0123456789abcdef".split(""); - return HEX_DIGITS[first] + HEX_DIGITS[second]; - }; - - const _uint8ArrayToHex = (buf: Uint8Array): string => { - let out = ""; - for (let i = 0; i < buf.length; i++) { - out += _uint8ToHex(buf[i]); - } - return out; - }; - - const escapedStr = encodeURIComponent(target); - const buffer = new Uint8Array(escapedStr.length); - for (let i = 0; i < escapedStr.length; i++) { - buffer[i] = escapedStr[i].charCodeAt(0); - } - - const hash = sha1(buffer); - const hashBuffer = new Uint8Array(hash.length / 2); - for (let i = 0; i < hash.length; i += 2) { - hashBuffer[i / 2] = Number.parseInt(hash.slice(i, i + 2), 16); - } - - return (_uint8ArrayToHex(hashBuffer.slice(0, 4)) + - "-" + - _uint8ArrayToHex(hashBuffer.slice(4, 6)) + - "-" + - _uint8ToHex(hashBuffer[6] & 0x0f) + - _uint8ToHex(hashBuffer[7]) + - "-" + - _uint8ToHex((hashBuffer[8] & 0x3f) | 0x80) + - _uint8ToHex(hashBuffer[9]) + - "-" + - _uint8ArrayToHex(hashBuffer.slice(10, 16))) as UUID; + return out; + }; + + const escapedStr = encodeURIComponent(target); + const buffer = new Uint8Array(escapedStr.length); + for (let i = 0; i < escapedStr.length; i++) { + buffer[i] = escapedStr[i].charCodeAt(0); + } + + const hash = sha1(buffer); + const hashBuffer = new Uint8Array(hash.length / 2); + for (let i = 0; i < hash.length; i += 2) { + hashBuffer[i / 2] = Number.parseInt(hash.slice(i, i + 2), 16); + } + + return (_uint8ArrayToHex(hashBuffer.slice(0, 4)) + + "-" + + _uint8ArrayToHex(hashBuffer.slice(4, 6)) + + "-" + + _uint8ToHex(hashBuffer[6] & 0x0f) + + _uint8ToHex(hashBuffer[7]) + + "-" + + _uint8ToHex((hashBuffer[8] & 0x3f) | 0x80) + + _uint8ToHex(hashBuffer[9]) + + "-" + + _uint8ArrayToHex(hashBuffer.slice(10, 16))) as UUID; } diff --git a/packages/plugin-bootstrap/__tests__/actions/continue.test.ts b/packages/plugin-bootstrap/__tests__/actions/continue.test.ts index f94981c8bfd..44ca3664bb3 100644 --- a/packages/plugin-bootstrap/__tests__/actions/continue.test.ts +++ b/packages/plugin-bootstrap/__tests__/actions/continue.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it, vi, beforeEach } from 'vitest'; import { continueAction } from '../../src/actions/continue'; -import { composeContext, generateMessageResponse, generateTrueOrFalse, ModelClass } from '@elizaos/core'; +import { composeContext, generateMessageResponse, generateTrueOrFalse, ModelType } from '@elizaos/core'; vi.mock('@elizaos/core', () => ({ composeContext: vi.fn(), @@ -14,7 +14,7 @@ vi.mock('@elizaos/core', () => ({ }, messageCompletionFooter: '\nResponse format:\n```\n{"content": {"text": string}}\n```', booleanFooter: '\nResponse format: YES or NO', - ModelClass: { + ModelType: { SMALL: 'small', LARGE: 'large' } diff --git a/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts b/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts index 70c11bf1da7..724349f6abc 100644 --- a/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts +++ b/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts @@ -21,7 +21,7 @@ vi.mock('@elizaos/core', () => ({ } }) })), - ModelClass: { + ModelType: { SMALL: 'small' } })); diff --git a/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts b/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts index 2b278321f9d..cebb63d18e6 100644 --- a/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts +++ b/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts @@ -7,7 +7,7 @@ vi.mock('@elizaos/core', () => ({ generateText: vi.fn(), getGoals: vi.fn(), parseJsonArrayFromText: vi.fn(), - ModelClass: { + ModelType: { SMALL: 'small' } })); diff --git a/packages/plugin-bootstrap/src/actions/continue.ts b/packages/plugin-bootstrap/src/actions/continue.ts index c2635322f37..5f0b99e7a0c 100644 --- a/packages/plugin-bootstrap/src/actions/continue.ts +++ b/packages/plugin-bootstrap/src/actions/continue.ts @@ -8,7 +8,7 @@ import { type HandlerCallback, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type State, } from "@elizaos/core"; @@ -167,7 +167,7 @@ export const continueAction: Action = { const response = await generateTrueOrFalse({ context: shouldRespondContext, - modelClass: ModelClass.TEXT_SMALL, + modelType: ModelType.TEXT_SMALL, runtime, }); @@ -194,7 +194,7 @@ export const continueAction: Action = { const response = await generateMessageResponse({ runtime, context, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); response.inReplyTo = message.id; diff --git a/packages/plugin-bootstrap/src/actions/followRoom.ts b/packages/plugin-bootstrap/src/actions/followRoom.ts index cdac3ad7226..bb6940444c9 100644 --- a/packages/plugin-bootstrap/src/actions/followRoom.ts +++ b/packages/plugin-bootstrap/src/actions/followRoom.ts @@ -6,7 +6,7 @@ import { type ActionExample, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type State, } from "@elizaos/core"; @@ -67,7 +67,7 @@ export const followRoomAction: Action = { const response = await generateTrueOrFalse({ runtime, context: shouldFollowContext, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); return response; diff --git a/packages/plugin-bootstrap/src/actions/muteRoom.ts b/packages/plugin-bootstrap/src/actions/muteRoom.ts index d1ae5e5d4f6..847dfdb24f2 100644 --- a/packages/plugin-bootstrap/src/actions/muteRoom.ts +++ b/packages/plugin-bootstrap/src/actions/muteRoom.ts @@ -6,7 +6,7 @@ import { type ActionExample, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type State, } from "@elizaos/core"; @@ -54,7 +54,7 @@ export const muteRoomAction: Action = { const response = await generateTrueOrFalse({ runtime, context: shouldMuteContext, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); return response; diff --git a/packages/plugin-bootstrap/src/actions/unfollowRoom.ts b/packages/plugin-bootstrap/src/actions/unfollowRoom.ts index 0598445aecb..9e0b346cb74 100644 --- a/packages/plugin-bootstrap/src/actions/unfollowRoom.ts +++ b/packages/plugin-bootstrap/src/actions/unfollowRoom.ts @@ -6,7 +6,7 @@ import { type ActionExample, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type State, } from "@elizaos/core"; @@ -52,7 +52,7 @@ export const unfollowRoomAction: Action = { const response = await generateTrueOrFalse({ runtime, context: shouldUnfollowContext, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); return response; diff --git a/packages/plugin-bootstrap/src/actions/unmuteRoom.ts b/packages/plugin-bootstrap/src/actions/unmuteRoom.ts index 308cec076da..56d48e98c97 100644 --- a/packages/plugin-bootstrap/src/actions/unmuteRoom.ts +++ b/packages/plugin-bootstrap/src/actions/unmuteRoom.ts @@ -6,7 +6,7 @@ import { type ActionExample, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type State, } from "@elizaos/core"; @@ -52,7 +52,7 @@ export const unmuteRoomAction: Action = { const response = generateTrueOrFalse({ context: shouldUnmuteContext, runtime, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); return response; diff --git a/packages/plugin-bootstrap/src/evaluators/fact.ts b/packages/plugin-bootstrap/src/evaluators/fact.ts index 4ed418569de..d0d6e540d2e 100644 --- a/packages/plugin-bootstrap/src/evaluators/fact.ts +++ b/packages/plugin-bootstrap/src/evaluators/fact.ts @@ -6,7 +6,7 @@ import { type ActionExample, type IAgentRuntime, type Memory, - ModelClass, + ModelType, type Evaluator, } from "@elizaos/core"; @@ -75,7 +75,7 @@ async function handler(runtime: IAgentRuntime, message: Memory) { const facts = await generateObjectArray({ runtime, context, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, schema: claimSchema, schemaName: "Fact", schemaDescription: "A fact about the user or the world", diff --git a/packages/plugin-bootstrap/src/evaluators/goal.ts b/packages/plugin-bootstrap/src/evaluators/goal.ts index be7a1f91f11..c527f8ff452 100644 --- a/packages/plugin-bootstrap/src/evaluators/goal.ts +++ b/packages/plugin-bootstrap/src/evaluators/goal.ts @@ -5,7 +5,7 @@ import { parseJsonArrayFromText } from "@elizaos/core"; import { type IAgentRuntime, type Memory, - ModelClass, + ModelType, type Goal, type State, type Evaluator, @@ -64,7 +64,7 @@ async function handler( const response = await generateText({ runtime, context, - modelClass: ModelClass.TEXT_LARGE, + modelType: ModelType.TEXT_LARGE, }); // Parse the JSON response to extract goal updates diff --git a/packages/plugin-bootstrap/src/providers/facts.ts b/packages/plugin-bootstrap/src/providers/facts.ts index 874ca6d21c6..6ce69cdbab0 100644 --- a/packages/plugin-bootstrap/src/providers/facts.ts +++ b/packages/plugin-bootstrap/src/providers/facts.ts @@ -1,10 +1,10 @@ +import type { Memory, Provider, State } from "@elizaos/core"; import { - embed, MemoryManager, + ModelType, formatMessages, type AgentRuntime as IAgentRuntime, } from "@elizaos/core"; -import type { Memory, Provider, State } from "@elizaos/core"; import { formatFacts } from "../evaluators/fact.ts"; const factsProvider: Provider = { @@ -16,22 +16,23 @@ const factsProvider: Provider = { actors: state?.actorsData, }); - const _embedding = await embed(runtime, recentMessages); + const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, { + text: recentMessages, + }); const memoryManager = new MemoryManager({ runtime, tableName: "facts", }); - const relevantFacts = []; - // await memoryManager.searchMemoriesByEmbedding( - // embedding, - // { - // roomId: message.roomId, - // count: 10, - // agentId: runtime.agentId, - // } - // ); + const relevantFacts = await memoryManager.searchMemoriesByEmbedding( + embedding, + { + roomId: message.roomId, + count: 10, + agentId: runtime.agentId, + } + ); const recentFactsData = await memoryManager.getMemories({ roomId: message.roomId, diff --git a/packages/plugin-openai/src/index.ts b/packages/plugin-openai/src/index.ts index 8caaccb76ac..d63eb922cd6 100644 --- a/packages/plugin-openai/src/index.ts +++ b/packages/plugin-openai/src/index.ts @@ -1,38 +1,64 @@ import { createOpenAI } from "@ai-sdk/openai"; import type { Plugin } from "@elizaos/core"; -import { GenerateTextParams, ModelClass } from "@elizaos/core"; +import { GenerateTextParams, ModelType } from "@elizaos/core"; import { generateText as aiGenerateText } from "ai"; export const openaiPlugin: Plugin = { - modelHandlers: { - [ModelClass.TEXT_EMBEDDING]: (text: string) => { - return text; + name: "openai", + description: "OpenAI plugin", + handlers: { + [ModelType.TEXT_EMBEDDING]: async (text: string | null) => { + if (!text) { + // Return zero vector of appropriate length for model + return new Array(1536).fill(0); + } + + const baseURL = process.env.OPENAI_BASE_URL ?? "https://api.openai.com/v1"; + + // use fetch to call embedding endpoint + const response = await fetch(`${baseURL}/embeddings`, { + method: "POST", + headers: { + "Authorization": `Bearer ${process.env.OPENAI_API_KEY}`, + "Content-Type": "application/json" + }, + body: JSON.stringify({ + model: "text-embedding-3-small", + input: text, + }) + }); + + const data = await response.json(); + return data.data[0].embedding; }, - [ModelClass.TEXT_GENERATION]: ({ + [ModelType.TEXT_LARGE]: async ({ runtime, context, - modelClass, - stop, + modelType, + stopSequences = ["\n"], }: GenerateTextParams ) => { - const variables = runtime.getPluginData("openai"); - const modelConfiguration = variables?.modelConfig; - const temperature = variables?.temperature; - const frequency_penalty = modelConfiguration?.frequency_penalty; - const presence_penalty = modelConfiguration?.presence_penalty; - const max_response_length = modelConfiguration?.maxOutputTokens; + // TODO: pull variables from json + // const variables = runtime.getPluginData("openai"); + + const temperature = 0.7; + const frequency_penalty = 0.7; + const presence_penalty = 0.7; + const max_response_length = 8192; const baseURL = process.env.OPENAI_BASE_URL ?? "https://api.openai.com/v1"; - //logger.debug("OpenAI baseURL result:", { baseURL }); const openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY, baseURL, fetch: runtime.fetch, }); - - // get the model name from the modelClass - const model = modelClass.name; + + const smallModel = process.env.OPENAI_SMALL_MODEL ?? process.env.SMALL_MODEL ?? "gpt-4o-mini"; + const largeModel = process.env.OPENAI_LARGE_MODEL ?? process.env.LARGE_MODEL ?? "gpt-4o"; + + // get the model name from the modelType + const model = modelType === ModelType.TEXT_SMALL ? smallModel : largeModel; const { text: openaiResponse } = await aiGenerateText({ model: openai.languageModel(model), @@ -42,6 +68,7 @@ export const openaiPlugin: Plugin = { maxTokens: max_response_length, frequencyPenalty: frequency_penalty, presencePenalty: presence_penalty, + stopSequences: stopSequences, }); return openaiResponse; diff --git a/packages/plugin-sqlite/src/index.ts b/packages/plugin-sqlite/src/index.ts index db7160a4717..d664865779a 100644 --- a/packages/plugin-sqlite/src/index.ts +++ b/packages/plugin-sqlite/src/index.ts @@ -1,26 +1,26 @@ -import path from "node:path"; import fs from "node:fs"; +import path from "node:path"; -export * from "./sqliteTables.ts"; export * from "./sqlite_vec.ts"; +export * from "./sqliteTables.ts"; -import { - DatabaseAdapter, - logger, - type IDatabaseCacheAdapter, -} from "@elizaos/core"; import type { Account, Actor, - GoalStatus, - Participant, + Adapter, Goal, + GoalStatus, + IAgentRuntime, Memory, + Participant, + Plugin, Relationship, UUID, - Adapter, - IAgentRuntime, - Plugin, +} from "@elizaos/core"; +import { + DatabaseAdapter, + logger, + type IDatabaseCacheAdapter, } from "@elizaos/core"; import type { Database as BetterSqlite3Database } from "better-sqlite3"; import { v4 } from "uuid"; @@ -146,13 +146,7 @@ export class SqliteDatabaseAdapter if (row === null) { return null; } - return { - ...row, - details: - typeof row.details === "string" - ? JSON.parse(row.details) - : row.details, - }; + return row; }) .filter((row): row is Actor => row !== null); }