diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts
index d4bdf8332cf..fdc4dc52599 100644
--- a/packages/agent/src/index.ts
+++ b/packages/agent/src/index.ts
@@ -6,15 +6,15 @@ import {
type Character,
type ClientInstance,
DbCacheAdapter,
- logger,
- FsCacheAdapter,
type IAgentRuntime,
type IDatabaseAdapter,
type IDatabaseCacheAdapter,
+ logger,
+ ModelType,
parseBooleanFromText,
settings,
stringToUuid,
- validateCharacterConfig,
+ validateCharacterConfig
} from "@elizaos/core";
import { bootstrapPlugin } from "@elizaos/plugin-bootstrap";
import fs from "node:fs";
@@ -349,10 +349,17 @@ export async function initializeClients(
clients.push(startedClient);
}
}
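+      // Register any model handlers the plugin exposes so runtime.call() can route requests to them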
+ if (plugin.handlers) {
+ for (const [modelType, handler] of Object.entries(plugin.handlers)) {
+ runtime.registerHandler(modelType as ModelType, handler);
+ }
+ }
}
}
- return clients;
+    runtime.clients = clients;
}
export async function createAgent(
@@ -374,18 +381,6 @@ export async function createAgent(
});
}
-function initializeFsCache(baseDir: string, character: Character) {
- if (!character?.id) {
- throw new Error(
- "initializeFsCache requires id to be set in character definition"
- );
- }
- const cacheDir = path.resolve(baseDir, character.id, "cache");
-
- const cache = new CacheManager(new FsCacheAdapter(cacheDir));
- return cache;
-}
-
function initializeDbCache(character: Character, db: IDatabaseCacheAdapter) {
if (!character?.id) {
throw new Error(
@@ -412,15 +407,6 @@ function initializeCache(
"Database adapter is not provided for CacheStore.Database."
);
- case CacheStore.FILESYSTEM:
- logger.info("Using File System Cache...");
- if (!baseDir) {
- throw new Error(
- "baseDir must be provided for CacheStore.FILESYSTEM."
- );
- }
- return initializeFsCache(baseDir, character);
-
default:
throw new Error(
`Invalid cache store: ${cacheStore} or required configuration missing.`
@@ -428,7 +414,6 @@ function initializeCache(
}
}
-
async function findDatabaseAdapter(runtime: AgentRuntime) {
const { adapters } = runtime;
let adapter: Adapter | undefined;
@@ -449,8 +434,6 @@ async function findDatabaseAdapter(runtime: AgentRuntime) {
return adapterInterface;
}
-
-
async function startAgent(
character: Character,
characterServer: CharacterServer
@@ -482,7 +465,7 @@ async function startAgent(
await runtime.initialize();
// start assigned clients
- runtime.clients = await initializeClients(character, runtime);
+ await initializeClients(character, runtime);
// add to container
characterServer.registerAgent(runtime);
diff --git a/packages/agent/src/server.ts b/packages/agent/src/server.ts
index f1ea0bc001e..cef6c4e4980 100644
--- a/packages/agent/src/server.ts
+++ b/packages/agent/src/server.ts
@@ -5,13 +5,13 @@ import {
generateImage,
generateMessageResponse,
generateObject,
- getEmbeddingZeroVector,
messageCompletionFooter,
- ModelClass,
+ ModelType,
stringToUuid,
type Content,
type Media,
- type Memory
+ type Memory,
+ type IAgentRuntime
} from "@elizaos/core";
import bodyParser from "body-parser";
import cors from "cors";
@@ -166,7 +166,7 @@ export class CharacterServer {
return;
}
- const transcription = await runtime.getModelProviderManager().call(ModelClass.AUDIO_TRANSCRIPTION, {
+ const transcription = await runtime.getModelProviderManager().call(ModelType.AUDIO_TRANSCRIPTION, {
file: fs.createReadStream(audioFile.path),
model: "whisper-1",
});
@@ -276,7 +276,7 @@ export class CharacterServer {
const response = await generateMessageResponse({
runtime: runtime,
context,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
if (!response) {
@@ -286,13 +286,17 @@ export class CharacterServer {
return;
}
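+            // An embedding request with no text returns a zero vector (replaces the old getEmbeddingZeroVector helper)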
+            const zeroVector = await runtime.getModelProviderManager().call(ModelType.TEXT_EMBEDDING, {
+                text: null,
+            });
+
// save response to memory
const responseMessage: Memory = {
id: stringToUuid(`${messageId}-${runtime.agentId}`),
...userMessage,
userId: runtime.agentId,
content: response,
- embedding: getEmbeddingZeroVector(),
+ embedding: zeroVector,
createdAt: Date.now(),
};
@@ -488,7 +492,7 @@ export class CharacterServer {
const response = await generateObject({
runtime,
context,
- modelClass: ModelClass.TEXT_SMALL,
+ modelType: ModelType.TEXT_SMALL,
schema: hyperfiOutSchema,
});
@@ -791,7 +795,7 @@ export class CharacterServer {
const response = await generateMessageResponse({
runtime: runtime,
context,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
// save response to memory
@@ -824,7 +828,7 @@ export class CharacterServer {
// Get the text to convert to speech
const textToSpeak = response.text;
- const speechResponse = await runtime.getModelProviderManager().call(ModelClass.AUDIO_TRANSCRIPTION, {
+ const speechResponse = await runtime.getModelProviderManager().call(ModelType.AUDIO_TRANSCRIPTION, {
text: textToSpeak,
runtime,
});
diff --git a/packages/client/src/components/overview.tsx b/packages/client/src/components/overview.tsx
index 68b77baec22..7d9ff031126 100644
--- a/packages/client/src/components/overview.tsx
+++ b/packages/client/src/components/overview.tsx
@@ -14,7 +14,6 @@ export default function Overview({ character }: { character: Character }) {
-
diff --git a/packages/core/__tests__/embedding.test.ts b/packages/core/__tests__/embedding.test.ts
deleted file mode 100644
--- a/packages/core/__tests__/embedding.test.ts
+++ /dev/null
-vi.mock("fastembed", () => {
- class MockFlagEmbedding {
-
-
- static async init() {
- return new MockFlagEmbedding();
- }
-
- async queryEmbed(text: string | string[]) {
- return [new Float32Array(384).fill(0.1)];
- }
- }
-
- return {
- FlagEmbedding: MockFlagEmbedding,
- EmbeddingModel: {
- BGESmallENV15: "BGE-small-en-v1.5",
- },
- };
-});
-
-// Mock fetch for remote embedding calls
-global.fetch = vi.fn();
-
-describe("Embedding Module", () => {
- let mockRuntime: IAgentRuntime;
-
- beforeEach(() => {
- // Reset all mocks
- vi.clearAllMocks();
-
- // Prepare a mock runtime
- mockRuntime = {
- agentId: "00000000-0000-0000-0000-000000000000" as `${string}-${string}-${string}-${string}-${string}`,
- providers: [],
- actions: [],
- evaluators: [],
- plugins: [],
- character: {
- name: "Test Character",
- username: "test",
- bio: ["Test bio"],
- lore: ["Test lore"],
- messageExamples: [],
- },
- getModelManager: () => (),
- messageManager: {
- getCachedEmbeddings: vi.fn().mockResolvedValue([])
- }
- } as unknown as IAgentRuntime;
-
- // Reset fetch mock with proper Response object
- const mockResponse = {
- ok: true,
- json: async () => ({
- data: [{ embedding: new Array(384).fill(0.1) }],
- }),
- headers: new Headers(),
- redirected: false,
- status: 200,
- statusText: "OK",
- type: "basic",
- url: "https://api.openai.com/v1/embeddings",
- body: null,
- bodyUsed: false,
- clone: () => ({} as Response),
- arrayBuffer: async () => new ArrayBuffer(0),
- blob: async () => new Blob(),
- formData: async () => new FormData(),
- text: async () => ""
- } as Response;
-
- vi.mocked(global.fetch).mockReset();
- vi.mocked(global.fetch).mockResolvedValue(mockResponse);
- });
-
- describe("getEmbeddingConfig", () => {
- test("should return OpenAI config by default", () => {
- const config = getEmbeddingConfig();
- expect(config.provider).toBe(EmbeddingProvider.OpenAI);
- expect(config.model).toBe("text-embedding-3-small");
- expect(config.dimensions).toBe(1536);
- });
-
- test("should use runtime provider when available", () => {
- const mockModelProvider = {
- provider: EmbeddingProvider.OpenAI,
- models: {
- [ModelClass.TEXT_EMBEDDING]: {
- name: "text-embedding-3-small",
- dimensions: 1536
- }
- }
- };
-
- const runtime = {
- getModelManager: () => mockModelProvider
- } as unknown as IAgentRuntime;
-
- const config = getEmbeddingConfig(runtime);
- expect(config.provider).toBe(EmbeddingProvider.OpenAI);
- expect(config.model).toBe("text-embedding-3-small");
- expect(config.dimensions).toBe(1536);
- });
- });
-
- describe("getEmbeddingType", () => {
- test("should return 'local' by default", () => {
- const type = getEmbeddingType(mockRuntime);
- expect(type).toBe("local");
- });
- });
-
- describe("getEmbeddingZeroVector", () => {
- test("should return 384-length zero vector by default (BGE)", () => {
- const vector = getEmbeddingZeroVector();
- expect(vector).toHaveLength(384);
- expect(vector.every((val) => val === 0)).toBe(true);
- });
-
- test("should return 1536-length zero vector for OpenAI if enabled", () => {
- const vector = getEmbeddingZeroVector();
- expect(vector).toHaveLength(1536);
- expect(vector.every((val) => val === 0)).toBe(true);
- });
- });
-
- describe("embed function", () => {
- test("should return an empty array for empty input text", async () => {
- const result = await embed(mockRuntime, "");
- expect(result).toEqual([]);
- });
-
- test("should return cached embedding if it already exists", async () => {
- const cachedEmbedding = new Array(384).fill(0.5);
- mockRuntime.messageManager.getCachedEmbeddings = vi
- .fn()
- .mockResolvedValue([{ embedding: cachedEmbedding }]);
-
- const result = await embed(mockRuntime, "test input");
- expect(result).toBe(cachedEmbedding);
- });
-
- test("should handle local embedding successfully", async () => {
- const result = await embed(mockRuntime, "test input");
- expect(result).toHaveLength(384);
- expect(result.every((v) => typeof v === "number")).toBe(true);
- });
-
- test("should handle remote embedding successfully", async () => {
- const result = await embed(mockRuntime, "test input");
- expect(result).toHaveLength(384);
- expect(vi.mocked(global.fetch)).toHaveBeenCalled();
- });
-
- test("should throw on remote embedding if fetch fails", async () => {
- // Mock fetch to reject
- vi.mocked(global.fetch).mockRejectedValueOnce(new Error("API Error"));
-
- await expect(embed(mockRuntime, "test input")).rejects.toThrow("API Error");
- });
-
- test("should handle concurrent embedding requests", async () => {
- const promises = Array(5)
- .fill(null)
- .map(() => embed(mockRuntime, "concurrent test"));
- await expect(Promise.all(promises)).resolves.toBeDefined();
- });
- });
-
- // Add tests for new embedding configurations
- describe("embedding configuration", () => {
- test("should handle embedding provider configuration", async () => {
- const mockModelProvider = {
- generateText: vi.fn(),
- generateObject: vi.fn(),
- generateImage: vi.fn(),
- generateEmbedding: vi.fn(),
- provider: EmbeddingProvider.OpenAI,
- models: {
- [ModelClass.TEXT_EMBEDDING]: {
- name: "text-embedding-3-small",
- dimensions: 1536
- }
- },
- getModelManager: () => mockModelProvider
- };
-
- const runtime = {
- agentId: "test-agent",
- databaseAdapter: {} as IDatabaseAdapter,
- getModelManager: () => mockModelProvider,
- } as unknown as IAgentRuntime;
-
- const config = getEmbeddingConfig(runtime);
- expect(config.provider).toBe(EmbeddingProvider.OpenAI);
- expect(config.model).toBe("text-embedding-3-small");
- expect(config.dimensions).toBe(1536);
- });
-
- test("should return default config when no runtime provided", () => {
- const config = getEmbeddingConfig();
- expect(config.provider).toBe(EmbeddingProvider.OpenAI);
- expect(config.model).toBe("text-embedding-3-small");
- expect(config.dimensions).toBe(1536);
- });
- });
-
- describe("embedding type detection", () => {
- test("should determine embedding type based on runtime configuration", () => {
- const mockRuntimeLocal = {
- ...mockRuntime,
- getSetting: (key: string) => "false"
- } as IAgentRuntime;
-
- expect(getEmbeddingType(mockRuntimeLocal)).toBe("local");
- });
- });
-});
diff --git a/packages/core/__tests__/knowledge.test.ts b/packages/core/__tests__/knowledge.test.ts
index cc03f98ea06..78ed60188b1 100644
--- a/packages/core/__tests__/knowledge.test.ts
+++ b/packages/core/__tests__/knowledge.test.ts
@@ -1,15 +1,7 @@
-import { describe, it, expect, vi, beforeEach } from "vitest";
+import { beforeEach, describe, expect, it, vi } from "vitest";
import knowledge from "../src/knowledge";
import type { AgentRuntime } from "../src/runtime";
-import { KnowledgeItem, type Memory } from "../src/types";
-
-// Mock dependencies
-vi.mock("../embedding", () => ({
- embed: vi.fn().mockResolvedValue(new Float32Array(1536).fill(0)),
- getEmbeddingZeroVector: vi
- .fn()
- .mockReturnValue(new Float32Array(1536).fill(0)),
-}));
+import { type Memory } from "../src/types";
vi.mock("../generation", () => ({
splitChunks: vi.fn().mockImplementation(async (text) => [text]),
diff --git a/packages/core/__tests__/runtime.test.ts b/packages/core/__tests__/runtime.test.ts
index f82d21a8dab..dc498298733 100644
--- a/packages/core/__tests__/runtime.test.ts
+++ b/packages/core/__tests__/runtime.test.ts
@@ -5,19 +5,11 @@ import {
type IDatabaseAdapter,
type IMemoryManager,
type Memory,
- ModelClass,
- ServiceType,
+ ModelType,
type UUID
} from "../src/types";
import { mockCharacter } from "./mockCharacter";
-// Mock the embedding module
-vi.mock("../src/embedding", () => ({
- embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]),
- getRemoteEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])),
- getLocalEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3]))
-}));
-
// Mock dependencies with minimal implementations
const mockDatabaseAdapter: IDatabaseAdapter = {
db: {},
@@ -86,9 +78,6 @@ describe("AgentRuntime", () => {
beforeEach(() => {
vi.clearAllMocks();
- const ModelManager = {
- getProvider: () => mockModelProvider,
- };
runtime = new AgentRuntime({
character: {
@@ -142,23 +131,9 @@ describe("AgentRuntime", () => {
});
});
- describe("service management", () => {
- it("should allow registering and retrieving services", async () => {
- const mockService = {
- serviceType: ServiceType.TEXT_GENERATION,
- type: ServiceType.TEXT_GENERATION,
- initialize: vi.fn().mockResolvedValue(undefined),
- };
-
- await runtime.registerService(mockService);
- const retrievedService = runtime.getService(ServiceType.TEXT_GENERATION);
- expect(retrievedService).toBe(mockService);
- });
- });
-
describe("model provider management", () => {
it("should provide access to the configured model provider", () => {
- const provider = runtime.getModelManager();
+ const provider = runtime;
expect(provider).toBeDefined();
});
});
@@ -239,17 +214,17 @@ describe("Model Provider Configuration", () => {
cacheManager: mockCacheManager,
});
- const provider = runtime.getModelManager();
+ const provider = runtime;
expect(provider.models.default).toBeDefined();
expect(provider.models.default.name).toBeDefined();
});
test("should handle missing optional model configurations", () => {
- const provider = runtime.getModelManager();
+ const provider = runtime;
// These might be undefined but shouldn't throw errors
- expect(() => provider.models[ModelClass.IMAGE]).not.toThrow();
- expect(() => provider.models[ModelClass.IMAGE_VISION]).not.toThrow();
+      expect(() => provider.models[ModelType.IMAGE]).not.toThrow();
+ expect(() => provider.models[ModelType.IMAGE_DESCRIPTION]).not.toThrow();
});
test("should validate model provider name format", () => {
@@ -260,7 +235,6 @@ describe("Model Provider Configuration", () => {
expect(() => new AgentRuntime({
character: {
...mockCharacter,
- modelProvider: invalidProvider,
bio: ["Test bio"], // Ensure bio is an array
lore: ["Test lore"], // Ensure lore is an array
messageExamples: [], // Required by Character type
@@ -285,7 +259,6 @@ describe("Model Provider Configuration", () => {
expect(() => new AgentRuntime({
character: {
...mockCharacter,
- modelProvider: validProvider,
bio: ["Test bio"], // Ensure bio is an array
lore: ["Test lore"], // Ensure lore is an array
messageExamples: [], // Required by Character type
@@ -298,7 +271,6 @@ describe("Model Provider Configuration", () => {
post: []
}
},
- modelProvider: validProvider,
databaseAdapter: mockDatabaseAdapter,
cacheManager: mockCacheManager,
})).not.toThrow();
@@ -307,22 +279,6 @@ describe("Model Provider Configuration", () => {
});
});
-describe("ModelManager", () => {
- test("should get correct model provider settings", async () => {
- const runtime = new AgentRuntime({
- databaseAdapter: mockDatabaseAdapter,
- cacheManager: {
- get: vi.fn(),
- set: vi.fn(),
- delete: vi.fn(),
- },
- });
-
- const provider = runtime.getModelManager();
- expect(provider).toBeDefined();
- });
-});
-
describe("MemoryManagerService", () => {
test("should provide access to different memory managers", async () => {
const runtime = new AgentRuntime({
@@ -361,23 +317,4 @@ describe("MemoryManagerService", () => {
runtime.registerMemoryManager(customManager);
expect(runtime.getMemoryManager("custom")).toBe(customManager);
});
-});
-
-describe("ServiceManager", () => {
- test("should handle service registration and retrieval", async () => {
- const runtime = new AgentRuntime({
- databaseAdapter: mockDatabaseAdapter,
- cacheManager: mockCacheManager
- });
-
- const mockService = {
- serviceType: ServiceType.TEXT_GENERATION,
- type: ServiceType.TEXT_GENERATION,
- initialize: vi.fn().mockResolvedValue(undefined)
- };
-
- await runtime.registerService(mockService);
- const retrievedService = runtime.getService(ServiceType.TEXT_GENERATION);
- expect(retrievedService).toBe(mockService);
- });
-});
+});
\ No newline at end of file
diff --git a/packages/core/src/cache.ts b/packages/core/src/cache.ts
index 7cf2a38f9ce..ced8218cca9 100644
--- a/packages/core/src/cache.ts
+++ b/packages/core/src/cache.ts
@@ -33,39 +33,6 @@ export class MemoryCacheAdapter implements ICacheAdapter {
}
}
-export class FsCacheAdapter implements ICacheAdapter {
- constructor(private dataDir: string) {}
-
- async get(key: string): Promise<string | undefined> {
- try {
- return await fs.readFile(path.join(this.dataDir, key), "utf8");
- } catch {
- // console.error(error);
- return undefined;
- }
- }
-
- async set(key: string, value: string): Promise<void> {
- try {
- const filePath = path.join(this.dataDir, key);
- // Ensure the directory exists
- await fs.mkdir(path.dirname(filePath), { recursive: true });
- await fs.writeFile(filePath, value, "utf8");
- } catch (error) {
- console.error(error);
- }
- }
-
- async delete(key: string): Promise<void> {
- try {
- const filePath = path.join(this.dataDir, key);
- await fs.unlink(filePath);
- } catch {
- // console.error(error);
- }
- }
-}
-
export class DbCacheAdapter implements ICacheAdapter {
constructor(
private db: IDatabaseCacheAdapter,
diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts
deleted file mode 100644
index a444f252855..00000000000
--- a/packages/core/src/embedding.ts
+++ /dev/null
@@ -1,69 +0,0 @@
-// TODO: Maybe create these functions to read from character settings or env
-// import { getEmbeddingModelSettings, getEndpoint } from "./models.ts";
-import logger from "./logger.ts";
-import { type IAgentRuntime } from "./types.ts";
-
-export type EmbeddingConfig = {
- readonly dimensions: number;
- readonly model: string;
- readonly provider: string;
-};
-
-/**
- * Gets embeddings from a remote API endpoint. Falls back to local BGE/384
- *
- * @param {string} input - The text to generate embeddings for
- * @param {EmbeddingOptions} options - Configuration options including:
- * - model: The model name to use
- * - endpoint: Base API endpoint URL
- * - apiKey: Optional API key for authentication
- * - isOllama: Whether this is an Ollama endpoint
- * - dimensions: Desired embedding dimensions
- * @param {IAgentRuntime} runtime - The agent runtime context
- * @returns {Promise<number[]>} Array of embedding values
- * @throws {Error} If the API request fails
- */
-
-export async function embed(runtime: IAgentRuntime, input: string) {
- logger.debug("Embedding request:", {
- input: `${input?.slice(0, 50)}...`,
- inputType: typeof input,
- inputLength: input?.length,
- isString: typeof input === "string",
- isEmpty: !input,
- });
-
- // Validate input
- if (!input || typeof input !== "string" || input.trim().length === 0) {
- logger.warn("Invalid embedding input:", {
- input,
- type: typeof input,
- length: input?.length,
- });
- return []; // Return empty embedding array
- }
-
- // Check cache first
- const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
- if (cachedEmbedding) return cachedEmbedding;
-
- // Use remote embedding
- return await runtime.getModelManager().generateEmbedding(input);
-
- async function retrieveCachedEmbedding(
- runtime: IAgentRuntime,
- input: string
- ) {
- if (!input) {
- logger.log("No input to retrieve cached embedding for");
- return null;
- }
-
- const similaritySearchResult =
- await runtime.messageManager.getCachedEmbeddings(input);
- if (similaritySearchResult.length > 0) {
- return similaritySearchResult[0].embedding;
- }
- return null;
- }
-}
diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts
index 43fc70d3b27..c87f9f1abe5 100644
--- a/packages/core/src/generation.ts
+++ b/packages/core/src/generation.ts
@@ -5,13 +5,13 @@ import { parseJSONObjectFromText } from "./parsing.ts";
import {
type Content,
type IAgentRuntime,
- ModelClass
+ ModelType
} from "./types.ts";
interface GenerateObjectOptions {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
output?: "object" | "array" | "enum" | "no-schema" | undefined;
schema?: ZodSchema;
schemaName?: string;
@@ -69,21 +69,20 @@ async function withRetry(
export async function generateText({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
stopSequences?: string[];
customSystemPrompt?: string;
}): Promise<string> {
logFunctionCall("generateText", runtime);
- const { text } = await runtime.getModelManager().generateText({
+ const { text } = await runtime.call(modelType, {
context,
- modelClass,
- stop: stopSequences,
+ stopSequences,
});
return text;
@@ -92,12 +91,12 @@ export async function generateText({
export async function generateTextArray({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
stopSequences?: string[];
}): Promise<string[]> {
logFunctionCall("generateTextArray", runtime);
@@ -106,7 +105,7 @@ export async function generateTextArray({
const result = await generateObject({
runtime,
context,
- modelClass,
+ modelType,
schema: z.array(z.string()),
stopSequences,
});
@@ -120,14 +119,14 @@ export async function generateTextArray({
async function generateEnum({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
enumValues,
functionName,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
enumValues: Array<string>;
functionName: string;
stopSequences?: string[];
@@ -142,7 +141,7 @@ async function generateEnum({
const result = await generateObject({
runtime,
context,
- modelClass,
+ modelType,
output: "enum",
enum: enumValues,
mode: "json",
@@ -159,12 +158,12 @@ async function generateEnum({
export async function generateShouldRespond({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
stopSequences?: string[];
}): Promise<"RESPOND" | "IGNORE" | "STOP" | null> {
const RESPONSE_VALUES = ["RESPOND", "IGNORE", "STOP"] as string[];
@@ -172,7 +171,7 @@ export async function generateShouldRespond({
const result = await generateEnum({
runtime,
context,
- modelClass,
+ modelType,
enumValues: RESPONSE_VALUES,
functionName: "generateShouldRespond",
stopSequences,
@@ -184,12 +183,12 @@ export async function generateShouldRespond({
export async function generateTrueOrFalse({
runtime,
context = "",
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
stopSequences?: string[];
}): Promise<boolean> {
logFunctionCall("generateTrueOrFalse", runtime);
@@ -199,7 +198,7 @@ export async function generateTrueOrFalse({
const result = await generateEnum({
runtime,
context,
- modelClass,
+ modelType,
enumValues: BOOL_VALUES,
functionName: "generateTrueOrFalse",
stopSequences,
@@ -212,7 +211,7 @@ export async function generateTrueOrFalse({
export const generateObject = async ({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
output = "object",
schema,
schemaName,
@@ -228,27 +227,26 @@ export const generateObject = async ({
throw new Error(errorMessage);
}
- const { object } = await runtime.getModelManager().generateObject({
+ const { object } = await runtime.call(modelType, {
context,
- modelClass,
stop: stopSequences,
});
- logger.debug(`Received Object response from ${modelClass} model.`);
+ logger.debug(`Received Object response from ${modelType} model.`);
return schema ? schema.parse(object) : object;
};
export async function generateObjectArray({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
schema,
schemaName,
schemaDescription,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
schema?: ZodSchema;
schemaName?: string;
schemaDescription?: string;
@@ -261,7 +259,7 @@ export async function generateObjectArray({
const result = await generateObject({
runtime,
context,
- modelClass,
+ modelType,
output: "array",
schema,
schemaName,
@@ -274,12 +272,12 @@ export async function generateObjectArray({
export async function generateMessageResponse({
runtime,
context,
- modelClass = ModelClass.TEXT_SMALL,
+ modelType = ModelType.TEXT_SMALL,
stopSequences,
}: {
runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
+ modelType: ModelType;
stopSequences?: string[];
}): Promise<Content> {
logFunctionCall("generateMessageResponse", runtime);
@@ -287,9 +285,8 @@ export async function generateMessageResponse({
logger.debug("Context:", context);
return await withRetry(async () => {
- const { text } = await runtime.getModelManager().generateText({
+ const { text } = await runtime.call(modelType, {
context,
- modelClass,
stop: stopSequences,
});
@@ -334,7 +331,7 @@ export const generateImage = async (
return await withRetry(
async () => {
- const result = await runtime.getModelManager().generateImage(data);
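+            // The IMAGE handler is expected to return { images: string[] }, mapped into the response data below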
+ const result = await runtime.call(ModelType.IMAGE, data);
return {
success: true,
data: result.images,
@@ -357,13 +354,8 @@ export const generateCaption = async (
}> => {
logFunctionCall("generateCaption", runtime);
const { imageUrl } = data;
- const imageDescriptionService = runtime.getModelManager().describeImage(imageUrl);
-
- if (!imageDescriptionService) {
- throw new Error("Image description service not found");
- }
+ const resp = await runtime.call(ModelType.IMAGE_DESCRIPTION, imageUrl);
- const resp = await imageDescriptionService.describeImage(imageUrl);
return {
title: resp.title.trim(),
description: resp.description.trim(),
diff --git a/packages/core/src/helper.ts b/packages/core/src/helper.ts
index 25c9479a741..7176680b81a 100644
--- a/packages/core/src/helper.ts
+++ b/packages/core/src/helper.ts
@@ -1,18 +1,16 @@
-import { AutoTokenizer } from "@huggingface/transformers";
import { encodingForModel, type TiktokenModel } from "js-tiktoken";
-import logger from "./logger.ts";
-import { TokenizerType, type IAgentRuntime, type ModelSettings } from "./types.ts";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
+import logger from "./logger.ts";
+import { type IAgentRuntime, type ModelSettings } from "./types.ts";
export function logFunctionCall(functionName: string, runtime?: IAgentRuntime) {
logger.info(`Function call: ${functionName}`, {
functionName,
- // runtime: JSON.stringify(runtime?.getModelManager())
+        // runtime: JSON.stringify(runtime)
});
}
-
export async function trimTokens(
context: string,
maxTokens: number,
@@ -30,52 +28,16 @@ export async function trimTokens(
return truncateTiktoken("gpt-4o", context, maxTokens);
}
- // Choose the truncation method based on tokenizer type
- if (tokenizerType === TokenizerType.Auto) {
- return truncateAuto(tokenizerModel, context, maxTokens);
- }
-
- if (tokenizerType === TokenizerType.TikToken) {
- return truncateTiktoken(
- tokenizerModel as TiktokenModel,
- context,
- maxTokens
- );
- }
+ return truncateTiktoken(
+ tokenizerModel as TiktokenModel,
+ context,
+ maxTokens
+ );
logger.warn(`Unsupported tokenizer type: ${tokenizerType}`);
return truncateTiktoken("gpt-4o", context, maxTokens);
}
-
-
-
-async function truncateAuto(
- modelPath: string,
- context: string,
- maxTokens: number
-) {
- try {
- const tokenizer = await AutoTokenizer.from_pretrained(modelPath);
- const tokens = tokenizer.encode(context);
-
- // If already within limits, return unchanged
- if (tokens.length <= maxTokens) {
- return context;
- }
-
- // Keep the most recent tokens by slicing from the end
- const truncatedTokens = tokens.slice(-maxTokens);
-
- // Decode back to text - js-tiktoken decode() returns a string directly
- return tokenizer.decode(truncatedTokens);
- } catch (error) {
- logger.error("Error in trimTokens:", error);
- // Return truncated string if tokenization fails
- return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
- }
-}
-
async function truncateTiktoken(
model: TiktokenModel,
context: string,
@@ -104,7 +66,6 @@ async function truncateTiktoken(
}
}
-
export async function splitChunks(
content: string,
chunkSize = 512,
@@ -126,7 +87,6 @@ export async function splitChunks(
return chunks;
}
-
export function getModelSettings(modelSettings: Record<string, ModelSettings>) {
if (!modelSettings) {
throw new Error("MODEL_SETTINGS is not defined");
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 6e680f26e48..145983b7a9d 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -4,7 +4,6 @@ export * from "./actions.ts";
export * from "./cache.ts";
export * from "./context.ts";
export * from "./database.ts";
-export * from "./embedding.ts";
export * from "./environment.ts";
export * from "./evaluators.ts";
export * from "./generation.ts";
@@ -21,5 +20,4 @@ export * from "./relationships.ts";
export * from "./runtime.ts";
export * from "./settings.ts";
export * from "./types.ts";
-export * from "./utils.ts";
export * from "./uuid.ts";
diff --git a/packages/core/src/knowledge.ts b/packages/core/src/knowledge.ts
index 7dd9857c685..c6dda472a4c 100644
--- a/packages/core/src/knowledge.ts
+++ b/packages/core/src/knowledge.ts
@@ -1,9 +1,8 @@
-import type { AgentRuntime } from "./runtime.ts";
-import { embed, getEmbeddingZeroVector } from "./embedding.ts";
-import type { KnowledgeItem, UUID, Memory } from "./types.ts";
-import { stringToUuid } from "./uuid.ts";
import { splitChunks } from "./helper.ts";
import logger from "./logger.ts";
+import type { AgentRuntime } from "./runtime.ts";
+import { type KnowledgeItem, type Memory, ModelType, type UUID } from "./types.ts";
+import { stringToUuid } from "./uuid.ts";
async function get(
runtime: AgentRuntime,
@@ -32,7 +31,9 @@ async function get(
return [];
}
- const embedding = await embed(runtime, processed);
+ const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: processed,
+ });
const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding(
embedding,
{
@@ -70,6 +71,9 @@ async function set(
chunkSize = 512,
bleed = 20
) {
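+    // Store the parent document with a zero-vector embedding (text: null); per-fragment embeddings are generated below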
+ const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: null,
+ });
await runtime.documentsManager.createMemory({
id: item.id,
agentId: runtime.agentId,
@@ -77,14 +81,16 @@ async function set(
userId: runtime.agentId,
createdAt: Date.now(),
content: item.content,
- embedding: getEmbeddingZeroVector(),
+ embedding: embedding,
});
const preprocessed = preprocess(item.content.text);
const fragments = await splitChunks(preprocessed, chunkSize, bleed);
for (const fragment of fragments) {
- const embedding = await embed(runtime, fragment);
+ const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: fragment,
+ });
await runtime.knowledgeManager.createMemory({
// We namespace the knowledge base uuid to avoid id
// collision with the document above.
diff --git a/packages/core/src/memory.ts b/packages/core/src/memory.ts
index c51d54e691b..bd564662182 100644
--- a/packages/core/src/memory.ts
+++ b/packages/core/src/memory.ts
@@ -1,10 +1,10 @@
-import { embed, getEmbeddingZeroVector } from "./embedding.ts";
import logger from "./logger.ts";
-import type {
- IAgentRuntime,
- IMemoryManager,
- Memory,
- UUID,
+import {
+ ModelType,
+ type IAgentRuntime,
+ type IMemoryManager,
+ type Memory,
+ type UUID,
} from "./types.ts";
const defaultMatchThreshold = 0.1;
@@ -66,11 +66,15 @@ export class MemoryManager implements IMemoryManager {
try {
// Generate embedding from text content
- memory.embedding = await embed(this.runtime, memoryText);
+ memory.embedding = await this.runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: memoryText,
+ });
} catch (error) {
logger.error("Failed to generate embedding:", error);
// Fallback to zero vector if embedding fails
- memory.embedding = getEmbeddingZeroVector().slice();
+ memory.embedding = await this.runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: null,
+ });
}
return memory;
@@ -140,6 +144,7 @@ export class MemoryManager implements IMemoryManager {
match_threshold?: number;
count?: number;
roomId: UUID;
+ agentId: UUID;
unique?: boolean;
}
): Promise<Memory[]> {
diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts
index f3907106435..3c4b21a5692 100644
--- a/packages/core/src/runtime.ts
+++ b/packages/core/src/runtime.ts
@@ -1,6 +1,3 @@
-import { glob } from "glob";
-import { existsSync } from "node:fs";
-import { readFile } from "node:fs/promises";
import { join } from "node:path";
import { names, uniqueNamesGenerator } from "unique-names-generator";
import { v4 as uuidv4 } from "uuid";
@@ -33,24 +30,18 @@ import {
type Character,
type ClientInstance,
type DirectoryItem,
- type EmbeddingModelSettings,
type Evaluator,
- GenerateTextParams,
type Goal,
type HandlerCallback,
type IAgentRuntime,
type ICacheManager,
type IDatabaseAdapter,
- type ImageModelSettings,
type IMemoryManager,
- IModelManager,
type KnowledgeItem,
type Memory,
- ModelClass,
+ ModelType,
type Plugin,
type Provider,
- type Service,
- type ServiceType,
type State,
type UUID
} from "./types.ts";
@@ -125,131 +116,6 @@ class KnowledgeManager {
}
}
-/**
- * Manages model provider settings and configuration
- */
-class ModelManager implements IModelManager {
- private runtime: AgentRuntime;
-
- constructor(runtime: AgentRuntime) {
- this.runtime = runtime;
- }
-
- modelHandlers = new Map<ModelClass, ((params: any) => Promise<any>)[]>();
-
- registerModelHandler(modelClass: ModelClass, handler: (params: any) => Promise) {
- if (!this.modelHandlers.has(modelClass)) {
- this.modelHandlers.set(modelClass, []);
- }
- this.modelHandlers.get(modelClass)?.push(handler);
- }
-
- getModelHandler(modelClass: ModelClass): ((params: any) => Promise<any>) | undefined {
- const handlers = this.modelHandlers.get(modelClass);
- if (!handlers?.length) {
- return undefined;
- }
- return handlers[0];
- }
-
- getApiKey() {
- return this.runtime.getSetting("API_KEY");
- }
-
- // TODO: This could be more of a pure registration handler
- generateText (params: GenerateTextParams) {
- const handler = this.getModelHandler(params.modelClass);
- if (!handler) {
- throw new Error(`No handler found for ${params.modelClass}`);
- }
- return handler(params);
- }
-
- generateEmbedding (text: string) {
- const handler = this.getModelHandler(ModelClass.TEXT_EMBEDDING);
- if (!handler) {
- throw new Error(`No handler found for ${ModelClass.TEXT_EMBEDDING}`);
- }
- return handler(text);
- }
-
- generateImage (params: ImageModelSettings) {
- const handler = this.getModelHandler(ModelClass.IMAGE);
- if (!handler) {
- throw new Error(`No handler found for ${ModelClass.IMAGE}`);
- }
- return handler(params);
- }
-
- generateAudio (params: any) {
- const handler = this.getModelHandler(ModelClass.AUDIO);
- if (!handler) {
- throw new Error(`No handler found for ${ModelClass.AUDIO}`);
- }
- return handler(params);
- }
-}
-
-/**
- * Manages services and their lifecycle
- */
-class ServiceManager {
- private runtime: AgentRuntime;
- private services: Map;
-
- constructor(runtime: AgentRuntime) {
- this.runtime = runtime;
- this.services = new Map();
- }
-
- async registerService(service: Service): Promise {
- const serviceType = service.serviceType;
- logger.log(`${this.runtime.character.name}(${this.runtime.agentId}) - Registering service:`, serviceType);
-
- if (this.services.has(serviceType)) {
- logger.warn(
- `${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} is already registered. Skipping registration.`
- );
- return;
- }
-
- // Add the service to the services map
- this.services.set(serviceType, service);
- logger.success(`${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} registered successfully`);
- }
-
- getService(service: ServiceType): T | null {
- const serviceInstance = this.services.get(service);
- if (!serviceInstance) {
- logger.error(`Service ${service} not found`);
- return null;
- }
- return serviceInstance as T;
- }
-
- async initializeServices(): Promise {
- for (const [serviceType, service] of this.services.entries()) {
- try {
- await service.initialize(this.runtime);
- this.services.set(serviceType, service);
- logger.success(
- `${this.runtime.character.name}(${this.runtime.agentId}) - Service ${serviceType} initialized successfully`
- );
- } catch (error) {
- logger.error(
- `${this.runtime.character.name}(${this.runtime.agentId}) - Failed to initialize service ${serviceType}:`,
- error
- );
- throw error;
- }
- }
- }
-
- getAllServices(): Map {
- return this.services;
- }
-}
-
/**
* Manages memory-related operations and memory managers
*/
@@ -372,15 +238,14 @@ export class AgentRuntime implements IAgentRuntime {
readonly fetch = fetch;
public cacheManager!: ICacheManager;
public clients: ClientInstance[] = [];
- readonly services: Map;
public adapters: Adapter[];
private readonly knowledgeRoot: string;
- private readonly modelManager: IModelManager;
- private readonly serviceManager: ServiceManager;
private readonly memoryManagerService: MemoryManagerService;
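+  // Model handlers keyed by ModelType; plugins register handlers and call() dispatches to the first one registered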
+  handlers = new Map<ModelType, ((params: any) => Promise<any>)[]>();
+
constructor(opts: {
conversationLength?: number;
agentId?: UUID;
@@ -390,7 +255,6 @@ export class AgentRuntime implements IAgentRuntime {
evaluators?: Evaluator[];
plugins?: Plugin[];
providers?: Provider[];
- services?: Service[];
managers?: IMemoryManager[];
databaseAdapter?: IDatabaseAdapter;
fetch?: typeof fetch;
@@ -435,11 +299,6 @@ export class AgentRuntime implements IAgentRuntime {
this.cacheManager = opts.cacheManager;
}
- this.services = new Map();
-
- // Initialize managers - moved ModelManager initialization earlier
- this.modelManager = new ModelManager(this);
- this.serviceManager = new ServiceManager(this);
this.memoryManagerService = new MemoryManagerService(this, this.knowledgeRoot);
// Register additional memory managers from options
@@ -449,21 +308,6 @@ export class AgentRuntime implements IAgentRuntime {
}
}
- // Register services from options and plugins
- if (opts.services) {
- for (const service of opts.services) {
- this.registerService(service);
- }
- }
-
- for (const plugin of this.plugins) {
- if (plugin.services) {
- for (const service of plugin.services) {
- this.registerService(service);
- }
- }
- }
-
this.plugins = [
...(opts.character?.plugins ?? []),
...(opts.plugins ?? []),
@@ -478,10 +322,6 @@ export class AgentRuntime implements IAgentRuntime {
this.registerEvaluator(evaluator);
}
- for (const service of (plugin.services ?? [])) {
- this.registerService(service);
- }
-
for (const provider of (plugin.providers ?? [])) {
this.registerContextProvider(provider);
}
@@ -503,10 +343,6 @@ export class AgentRuntime implements IAgentRuntime {
this.adapters = opts.adapters ?? [];
}
- getModelManager(): IModelManager {
- return this.modelManager;
- }
-
async initialize() {
await this.ensureRoomExists(this.agentId);
await this.ensureUserExists(
@@ -516,8 +352,6 @@ export class AgentRuntime implements IAgentRuntime {
);
await this.ensureParticipantExists(this.agentId, this.agentId);
- await this.serviceManager.initializeServices();
-
if (this.character?.knowledge && this.character.knowledge.length > 0) {
// Non-RAG mode: only process string knowledge
const stringKnowledge = this.character.knowledge.filter(
@@ -730,7 +564,7 @@ export class AgentRuntime implements IAgentRuntime {
const result = await generateText({
runtime: this,
context,
- modelClass: ModelClass.TEXT_SMALL,
+ modelType: ModelType.TEXT_SMALL,
});
const evaluators = parseJsonArrayFromText(
@@ -1366,14 +1200,6 @@ Text: ${attachment.text}
return this.memoryManagerService.getMemoryManager(tableName);
}
- getService(service: ServiceType): T | null {
- return this.serviceManager.getService(service);
- }
-
- async registerService(service: Service): Promise {
- await this.serviceManager.registerService(service);
- }
-
// Memory manager getters
get messageManager(): IMemoryManager {
return this.memoryManagerService.getMessageManager();
@@ -1394,4 +1220,27 @@ Text: ${attachment.text}
get knowledgeManager(): IMemoryManager {
return this.memoryManagerService.getKnowledgeManager();
}
+
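+  /** Registers a handler for the given model type (typically called while loading plugins). */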
+  registerHandler(modelType: ModelType, handler: (params: any) => Promise<any>) {
+ if (!this.handlers.has(modelType)) {
+ this.handlers.set(modelType, []);
+ }
+ this.handlers.get(modelType)?.push(handler);
+ }
+
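+  /** Returns the first handler registered for the model type, or undefined if none is registered. */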
+  getHandler(modelType: ModelType): ((params: any) => Promise<any>) | undefined {
+ const handlers = this.handlers.get(modelType);
+ if (!handlers?.length) {
+ return undefined;
+ }
+ return handlers[0];
+ }
+
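+  /** Routes a model request to the registered handler; throws if no handler exists for the model type. */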
+  call(modelType: ModelType, params: any): Promise<any> {
+ const handler = this.getHandler(modelType);
+ if (!handler) {
+ throw new Error(`No handler found for model type: ${modelType}`);
+ }
+ return handler(params);
+ }
}
diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts
index 0931969cb77..f004449420a 100644
--- a/packages/core/src/types.ts
+++ b/packages/core/src/types.ts
@@ -116,12 +116,12 @@ export interface Goal {
/**
* Model size/type classification
*/
-export enum ModelClass {
+export enum ModelType {
TEXT_SMALL = "text_small",
TEXT_LARGE = "text_large",
TEXT_EMBEDDING = "text_embedding",
IMAGE = "image",
- VISION = "vision",
+ IMAGE_DESCRIPTION = "image_description",
TRANSCRIPTION = "transcription",
TEXT_TO_SPEECH = "text_to_speech",
SPEECH_TO_TEXT = "speech_to_text",
@@ -561,14 +561,16 @@ export type Plugin = {
/** Optional evaluators */
evaluators?: Evaluator[];
- /** Optional services */
- services?: Service[];
-
/** Optional clients */
clients?: Client[];
/** Optional adapters */
adapters?: Adapter[];
+
+  /** Optional model handlers, keyed by ModelType */
+ handlers?: {
+    [key: string]: (...args: any[]) => Promise<any>;
+ };
};
export interface ModelConfiguration {
@@ -887,36 +889,6 @@ export interface ICacheManager {
delete(key: string): Promise<void>;
}
-export abstract class Service {
- private static instance: Service | null = null;
-
- static get serviceType(): ServiceType {
- throw new Error("Service must implement static serviceType getter");
- }
-
- public static getInstance(): T {
- if (!Service.instance) {
- Service.instance = new (this as any)();
- }
- return Service.instance as T;
- }
-
- get serviceType(): ServiceType {
- return (this.constructor as typeof Service).serviceType;
- }
-
- // Add abstract initialize method that must be implemented by derived classes
- abstract initialize(runtime: IAgentRuntime): Promise<void>;
-}
-
-export interface IModelManager {
- generateText(params: GenerateTextParams): Promise<{ text: string }>;
- generateObject(params: GenerateTextParams): Promise<{ object: any }>;
- generateEmbedding(text: string): Promise<number[]>;
- generateImage(params: ImageModelSettings): Promise<{ images: string[] }>;
- generateAudio(params: any): Promise<{ audio: string }>;
-}
-
export interface IAgentRuntime {
// Properties
agentId: UUID;
@@ -938,7 +910,6 @@ export interface IAgentRuntime {
cacheManager: ICacheManager;
- services: Map;
clients: ClientInstance[];
initialize(): Promise<void>;
@@ -947,14 +918,8 @@ export interface IAgentRuntime {
getMemoryManager(name: string): IMemoryManager | null;
- getService(service: ServiceType): T | null;
-
- registerService(service: Service): void;
-
getSetting(key: string): string | null;
- getModelManager(): IModelManager;
-
// Methods
getConversationLength(): number;
@@ -1001,156 +966,10 @@ export interface IAgentRuntime {
): Promise;
updateRecentMessageState(state: State): Promise;
-}
-
-export interface IImageDescriptionService extends Service {
- describeImage(
- imageUrl: string
- ): Promise<{ title: string; description: string }>;
-}
-
-export interface ITranscriptionService extends Service {
- transcribeAttachment(audioBuffer: ArrayBuffer): Promise;
- transcribeAttachmentLocally(audioBuffer: ArrayBuffer): Promise;
- transcribe(audioBuffer: ArrayBuffer): Promise;
- transcribeLocally(audioBuffer: ArrayBuffer): Promise;
-}
-
-export interface IVideoService extends Service {
- isVideoUrl(url: string): boolean;
- fetchVideoInfo(url: string): Promise;
- downloadVideo(videoInfo: Media): Promise;
- processVideo(url: string, runtime: IAgentRuntime): Promise;
-}
-
-export interface ITextGenerationService extends Service {
- initializeModel(): Promise;
- queueMessageCompletion(
- context: string,
- temperature: number,
- stop: string[],
- frequency_penalty: number,
- presence_penalty: number,
- max_tokens: number
- ): Promise;
- queueTextCompletion(
- context: string,
- temperature: number,
- stop: string[],
- frequency_penalty: number,
- presence_penalty: number,
- max_tokens: number
- ): Promise;
- getEmbeddingResponse(input: string): Promise;
-}
-
-export interface IBrowserService extends Service {
- closeBrowser(): Promise;
- getPageContent(
- url: string,
- runtime: IAgentRuntime
- ): Promise<{ title: string; description: string; bodyContent: string }>;
-}
-export interface ISpeechService extends Service {
- getInstance(): ISpeechService;
- generate(runtime: IAgentRuntime, text: string): Promise;
-}
-
-export interface IPdfService extends Service {
- getInstance(): IPdfService;
- convertPdfToText(pdfBuffer: Buffer): Promise;
-}
-
-export interface IAwsS3Service extends Service {
- uploadFile(
- imagePath: string,
- subDirectory: string,
- useSignedUrl: boolean,
- expiresIn: number
- ): Promise<{
- success: boolean;
- url?: string;
- error?: string;
- }>;
- generateSignedUrl(fileName: string, expiresIn: number): Promise;
-}
-
-export interface UploadIrysResult {
- success: boolean;
- url?: string;
- error?: string;
- data?: any;
-}
-
-export interface DataIrysFetchedFromGQL {
- success: boolean;
- data: any;
- error?: string;
-}
-
-export interface GraphQLTag {
- name: string;
- values: any[];
-}
-
-export enum IrysMessageType {
- REQUEST = "REQUEST",
- DATA_STORAGE = "DATA_STORAGE",
- REQUEST_RESPONSE = "REQUEST_RESPONSE",
-}
-
-export enum IrysDataType {
- FILE = "FILE",
- IMAGE = "IMAGE",
- OTHER = "OTHER",
-}
-
-export interface IrysTimestamp {
- from: number;
- to: number;
-}
-
-export interface IIrysService extends Service {
- getDataFromAnAgent(
- agentsWalletPublicKeys: string[],
- tags: GraphQLTag[],
- timestamp: IrysTimestamp
- ): Promise;
- workerUploadDataOnIrys(
- data: any,
- dataType: IrysDataType,
- messageType: IrysMessageType,
- serviceCategory: string[],
- protocol: string[],
- validationThreshold: number[],
- minimumProviders: number[],
- testProvider: boolean[],
- reputation: number[]
- ): Promise;
- providerUploadDataOnIrys(
- data: any,
- dataType: IrysDataType,
- serviceCategory: string[],
- protocol: string[]
- ): Promise;
-}
-
-export interface ITeeLogService extends Service {
- getInstance(): ITeeLogService;
- log(
- agentId: string,
- roomId: string,
- userId: string,
- type: string,
- content: string
- ): Promise;
-}
-
-export enum ServiceType {
- BROWSER = "browser",
- PDF = "pdf",
- STORAGE = "STORAGE",
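+    // Model handler registry: plugins register handlers per ModelType and call() routes requests to them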
+    call(modelType: ModelType, params: any): Promise<any>;
+    registerHandler(modelType: ModelType, handler: (params: any) => Promise<any>): void;
+    getHandler(modelType: ModelType): ((params: any) => Promise<any>) | undefined;
}
export enum LoggingLevel {
@@ -1184,7 +1003,8 @@ export interface ChunkRow {
}
export type GenerateTextParams = {
+ runtime: IAgentRuntime;
context: string;
- modelClass: ModelClass;
- stop?: string[];
+ modelType: ModelType;
+ stopSequences?: string[];
};
diff --git a/packages/core/src/utils.ts b/packages/core/src/utils.ts
deleted file mode 100644
index 394087f04bf..00000000000
--- a/packages/core/src/utils.ts
+++ /dev/null
@@ -1,4 +0,0 @@
-export { embed } from "./embedding.ts";
-export { logger } from "./logger.ts";
-export { AgentRuntime } from "./runtime.ts";
-
diff --git a/packages/core/src/uuid.ts b/packages/core/src/uuid.ts
index 4cc8f09e700..180484ad810 100644
--- a/packages/core/src/uuid.ts
+++ b/packages/core/src/uuid.ts
@@ -5,55 +5,55 @@ import type { UUID } from "./types.ts";
export const uuidSchema = z.string().uuid() as z.ZodType;
export function validateUuid(value: unknown): UUID | null {
- const result = uuidSchema.safeParse(value);
- return result.success ? result.data : null;
+ const result = uuidSchema.safeParse(value);
+ return result.success ? result.data : null;
}
export function stringToUuid(target: string | number): UUID {
- if (typeof target === "number") {
- target = (target as number).toString();
+ if (typeof target === "number") {
+ target = (target as number).toString();
+ }
+
+ if (typeof target !== "string") {
+ throw TypeError("Value must be string");
+ }
+
+ const _uint8ToHex = (ubyte: number): string => {
+ const first = ubyte >> 4;
+ const second = ubyte - (first << 4);
+ const HEX_DIGITS = "0123456789abcdef".split("");
+ return HEX_DIGITS[first] + HEX_DIGITS[second];
+ };
+
+ const _uint8ArrayToHex = (buf: Uint8Array): string => {
+ let out = "";
+ for (let i = 0; i < buf.length; i++) {
+ out += _uint8ToHex(buf[i]);
}
-
- if (typeof target !== "string") {
- throw TypeError("Value must be string");
- }
-
- const _uint8ToHex = (ubyte: number): string => {
- const first = ubyte >> 4;
- const second = ubyte - (first << 4);
- const HEX_DIGITS = "0123456789abcdef".split("");
- return HEX_DIGITS[first] + HEX_DIGITS[second];
- };
-
- const _uint8ArrayToHex = (buf: Uint8Array): string => {
- let out = "";
- for (let i = 0; i < buf.length; i++) {
- out += _uint8ToHex(buf[i]);
- }
- return out;
- };
-
- const escapedStr = encodeURIComponent(target);
- const buffer = new Uint8Array(escapedStr.length);
- for (let i = 0; i < escapedStr.length; i++) {
- buffer[i] = escapedStr[i].charCodeAt(0);
- }
-
- const hash = sha1(buffer);
- const hashBuffer = new Uint8Array(hash.length / 2);
- for (let i = 0; i < hash.length; i += 2) {
- hashBuffer[i / 2] = Number.parseInt(hash.slice(i, i + 2), 16);
- }
-
- return (_uint8ArrayToHex(hashBuffer.slice(0, 4)) +
- "-" +
- _uint8ArrayToHex(hashBuffer.slice(4, 6)) +
- "-" +
- _uint8ToHex(hashBuffer[6] & 0x0f) +
- _uint8ToHex(hashBuffer[7]) +
- "-" +
- _uint8ToHex((hashBuffer[8] & 0x3f) | 0x80) +
- _uint8ToHex(hashBuffer[9]) +
- "-" +
- _uint8ArrayToHex(hashBuffer.slice(10, 16))) as UUID;
+ return out;
+ };
+
+ const escapedStr = encodeURIComponent(target);
+ const buffer = new Uint8Array(escapedStr.length);
+ for (let i = 0; i < escapedStr.length; i++) {
+ buffer[i] = escapedStr[i].charCodeAt(0);
+ }
+
+ const hash = sha1(buffer);
+ const hashBuffer = new Uint8Array(hash.length / 2);
+ for (let i = 0; i < hash.length; i += 2) {
+ hashBuffer[i / 2] = Number.parseInt(hash.slice(i, i + 2), 16);
+ }
+
+ return (_uint8ArrayToHex(hashBuffer.slice(0, 4)) +
+ "-" +
+ _uint8ArrayToHex(hashBuffer.slice(4, 6)) +
+ "-" +
+ _uint8ToHex(hashBuffer[6] & 0x0f) +
+ _uint8ToHex(hashBuffer[7]) +
+ "-" +
+ _uint8ToHex((hashBuffer[8] & 0x3f) | 0x80) +
+ _uint8ToHex(hashBuffer[9]) +
+ "-" +
+ _uint8ArrayToHex(hashBuffer.slice(10, 16))) as UUID;
}
diff --git a/packages/plugin-bootstrap/__tests__/actions/continue.test.ts b/packages/plugin-bootstrap/__tests__/actions/continue.test.ts
index f94981c8bfd..44ca3664bb3 100644
--- a/packages/plugin-bootstrap/__tests__/actions/continue.test.ts
+++ b/packages/plugin-bootstrap/__tests__/actions/continue.test.ts
@@ -1,6 +1,6 @@
import { describe, expect, it, vi, beforeEach } from 'vitest';
import { continueAction } from '../../src/actions/continue';
-import { composeContext, generateMessageResponse, generateTrueOrFalse, ModelClass } from '@elizaos/core';
+import { composeContext, generateMessageResponse, generateTrueOrFalse, ModelType } from '@elizaos/core';
vi.mock('@elizaos/core', () => ({
composeContext: vi.fn(),
@@ -14,7 +14,7 @@ vi.mock('@elizaos/core', () => ({
},
messageCompletionFooter: '\nResponse format:\n```\n{"content": {"text": string}}\n```',
booleanFooter: '\nResponse format: YES or NO',
- ModelClass: {
+ ModelType: {
SMALL: 'small',
LARGE: 'large'
}
diff --git a/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts b/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts
index 70c11bf1da7..724349f6abc 100644
--- a/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts
+++ b/packages/plugin-bootstrap/__tests__/evaluators/fact.test.ts
@@ -21,7 +21,7 @@ vi.mock('@elizaos/core', () => ({
}
})
})),
- ModelClass: {
+ ModelType: {
SMALL: 'small'
}
}));
diff --git a/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts b/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts
index 2b278321f9d..cebb63d18e6 100644
--- a/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts
+++ b/packages/plugin-bootstrap/__tests__/evaluators/goal.test.ts
@@ -7,7 +7,7 @@ vi.mock('@elizaos/core', () => ({
generateText: vi.fn(),
getGoals: vi.fn(),
parseJsonArrayFromText: vi.fn(),
- ModelClass: {
+ ModelType: {
SMALL: 'small'
}
}));
diff --git a/packages/plugin-bootstrap/src/actions/continue.ts b/packages/plugin-bootstrap/src/actions/continue.ts
index c2635322f37..5f0b99e7a0c 100644
--- a/packages/plugin-bootstrap/src/actions/continue.ts
+++ b/packages/plugin-bootstrap/src/actions/continue.ts
@@ -8,7 +8,7 @@ import {
type HandlerCallback,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type State,
} from "@elizaos/core";
@@ -167,7 +167,7 @@ export const continueAction: Action = {
const response = await generateTrueOrFalse({
context: shouldRespondContext,
- modelClass: ModelClass.TEXT_SMALL,
+ modelType: ModelType.TEXT_SMALL,
runtime,
});
@@ -194,7 +194,7 @@ export const continueAction: Action = {
const response = await generateMessageResponse({
runtime,
context,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
response.inReplyTo = message.id;
diff --git a/packages/plugin-bootstrap/src/actions/followRoom.ts b/packages/plugin-bootstrap/src/actions/followRoom.ts
index cdac3ad7226..bb6940444c9 100644
--- a/packages/plugin-bootstrap/src/actions/followRoom.ts
+++ b/packages/plugin-bootstrap/src/actions/followRoom.ts
@@ -6,7 +6,7 @@ import {
type ActionExample,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type State,
} from "@elizaos/core";
@@ -67,7 +67,7 @@ export const followRoomAction: Action = {
const response = await generateTrueOrFalse({
runtime,
context: shouldFollowContext,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
return response;
diff --git a/packages/plugin-bootstrap/src/actions/muteRoom.ts b/packages/plugin-bootstrap/src/actions/muteRoom.ts
index d1ae5e5d4f6..847dfdb24f2 100644
--- a/packages/plugin-bootstrap/src/actions/muteRoom.ts
+++ b/packages/plugin-bootstrap/src/actions/muteRoom.ts
@@ -6,7 +6,7 @@ import {
type ActionExample,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type State,
} from "@elizaos/core";
@@ -54,7 +54,7 @@ export const muteRoomAction: Action = {
const response = await generateTrueOrFalse({
runtime,
context: shouldMuteContext,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
return response;
diff --git a/packages/plugin-bootstrap/src/actions/unfollowRoom.ts b/packages/plugin-bootstrap/src/actions/unfollowRoom.ts
index 0598445aecb..9e0b346cb74 100644
--- a/packages/plugin-bootstrap/src/actions/unfollowRoom.ts
+++ b/packages/plugin-bootstrap/src/actions/unfollowRoom.ts
@@ -6,7 +6,7 @@ import {
type ActionExample,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type State,
} from "@elizaos/core";
@@ -52,7 +52,7 @@ export const unfollowRoomAction: Action = {
const response = await generateTrueOrFalse({
runtime,
context: shouldUnfollowContext,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
return response;
diff --git a/packages/plugin-bootstrap/src/actions/unmuteRoom.ts b/packages/plugin-bootstrap/src/actions/unmuteRoom.ts
index 308cec076da..56d48e98c97 100644
--- a/packages/plugin-bootstrap/src/actions/unmuteRoom.ts
+++ b/packages/plugin-bootstrap/src/actions/unmuteRoom.ts
@@ -6,7 +6,7 @@ import {
type ActionExample,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type State,
} from "@elizaos/core";
@@ -52,7 +52,7 @@ export const unmuteRoomAction: Action = {
const response = generateTrueOrFalse({
context: shouldUnmuteContext,
runtime,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
return response;
diff --git a/packages/plugin-bootstrap/src/evaluators/fact.ts b/packages/plugin-bootstrap/src/evaluators/fact.ts
index 4ed418569de..d0d6e540d2e 100644
--- a/packages/plugin-bootstrap/src/evaluators/fact.ts
+++ b/packages/plugin-bootstrap/src/evaluators/fact.ts
@@ -6,7 +6,7 @@ import {
type ActionExample,
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type Evaluator,
} from "@elizaos/core";
@@ -75,7 +75,7 @@ async function handler(runtime: IAgentRuntime, message: Memory) {
const facts = await generateObjectArray({
runtime,
context,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
schema: claimSchema,
schemaName: "Fact",
schemaDescription: "A fact about the user or the world",
diff --git a/packages/plugin-bootstrap/src/evaluators/goal.ts b/packages/plugin-bootstrap/src/evaluators/goal.ts
index be7a1f91f11..c527f8ff452 100644
--- a/packages/plugin-bootstrap/src/evaluators/goal.ts
+++ b/packages/plugin-bootstrap/src/evaluators/goal.ts
@@ -5,7 +5,7 @@ import { parseJsonArrayFromText } from "@elizaos/core";
import {
type IAgentRuntime,
type Memory,
- ModelClass,
+ ModelType,
type Goal,
type State,
type Evaluator,
@@ -64,7 +64,7 @@ async function handler(
const response = await generateText({
runtime,
context,
- modelClass: ModelClass.TEXT_LARGE,
+ modelType: ModelType.TEXT_LARGE,
});
// Parse the JSON response to extract goal updates
diff --git a/packages/plugin-bootstrap/src/providers/facts.ts b/packages/plugin-bootstrap/src/providers/facts.ts
index 874ca6d21c6..6ce69cdbab0 100644
--- a/packages/plugin-bootstrap/src/providers/facts.ts
+++ b/packages/plugin-bootstrap/src/providers/facts.ts
@@ -1,10 +1,10 @@
+import type { Memory, Provider, State } from "@elizaos/core";
import {
- embed,
MemoryManager,
+ ModelType,
formatMessages,
type AgentRuntime as IAgentRuntime,
} from "@elizaos/core";
-import type { Memory, Provider, State } from "@elizaos/core";
import { formatFacts } from "../evaluators/fact.ts";
const factsProvider: Provider = {
@@ -16,22 +16,23 @@ const factsProvider: Provider = {
actors: state?.actorsData,
});
- const _embedding = await embed(runtime, recentMessages);
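+ // produce an embedding for the recent conversation via the runtime's TEXT_EMBEDDING model handler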
+ const embedding = await runtime.call(ModelType.TEXT_EMBEDDING, {
+ text: recentMessages,
+ });
const memoryManager = new MemoryManager({
runtime,
tableName: "facts",
});
- const relevantFacts = [];
- // await memoryManager.searchMemoriesByEmbedding(
- // embedding,
- // {
- // roomId: message.roomId,
- // count: 10,
- // agentId: runtime.agentId,
- // }
- // );
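+ // semantic search over the "facts" table using the embedding computed above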
+ const relevantFacts = await memoryManager.searchMemoriesByEmbedding(
+ embedding,
+ {
+ roomId: message.roomId,
+ count: 10,
+ agentId: runtime.agentId,
+ }
+ );
const recentFactsData = await memoryManager.getMemories({
roomId: message.roomId,
diff --git a/packages/plugin-openai/src/index.ts b/packages/plugin-openai/src/index.ts
index 8caaccb76ac..d63eb922cd6 100644
--- a/packages/plugin-openai/src/index.ts
+++ b/packages/plugin-openai/src/index.ts
@@ -1,38 +1,64 @@
import { createOpenAI } from "@ai-sdk/openai";
import type { Plugin } from "@elizaos/core";
-import { GenerateTextParams, ModelClass } from "@elizaos/core";
+import { GenerateTextParams, ModelType } from "@elizaos/core";
import { generateText as aiGenerateText } from "ai";
export const openaiPlugin: Plugin = {
- modelHandlers: {
- [ModelClass.TEXT_EMBEDDING]: (text: string) => {
- return text;
+ name: "openai",
+ description: "OpenAI plugin",
+ handlers: {
+ [ModelType.TEXT_EMBEDDING]: async (text: string | null) => {
+ if (!text) {
+ // Return a zero vector matching the 1536-dim output of text-embedding-3-small
+ return new Array(1536).fill(0);
+ }
+
+ const baseURL = process.env.OPENAI_BASE_URL ?? "https://api.openai.com/v1";
+
+ // use fetch to call embedding endpoint
+ const response = await fetch(`${baseURL}/embeddings`, {
+ method: "POST",
+ headers: {
+ "Authorization": `Bearer ${process.env.OPENAI_API_KEY}`,
+ "Content-Type": "application/json"
+ },
+ body: JSON.stringify({
+ model: "text-embedding-3-small",
+ input: text,
+ })
+ });
+
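+ // surface HTTP errors instead of indexing into an error payload
+ if (!response.ok) {
+ throw new Error(`OpenAI embeddings request failed: ${response.status} ${response.statusText}`);
+ }
+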
+ const data = await response.json();
+ return data.data[0].embedding;
},
- [ModelClass.TEXT_GENERATION]: ({
+ [ModelType.TEXT_LARGE]: async ({
runtime,
context,
- modelClass,
- stop,
+ modelType,
+ stopSequences = ["\n"],
}: GenerateTextParams
) => {
- const variables = runtime.getPluginData("openai");
- const modelConfiguration = variables?.modelConfig;
- const temperature = variables?.temperature;
- const frequency_penalty = modelConfiguration?.frequency_penalty;
- const presence_penalty = modelConfiguration?.presence_penalty;
- const max_response_length = modelConfiguration?.maxOutputTokens;
+ // TODO: pull variables from json
+ // const variables = runtime.getPluginData("openai");
+
+ const temperature = 0.7;
+ const frequency_penalty = 0.7;
+ const presence_penalty = 0.7;
+ const max_response_length = 8192;
const baseURL = process.env.OPENAI_BASE_URL ?? "https://api.openai.com/v1";
- //logger.debug("OpenAI baseURL result:", { baseURL });
const openai = createOpenAI({
apiKey: process.env.OPENAI_API_KEY,
baseURL,
fetch: runtime.fetch,
});
-
- // get the model name from the modelClass
- const model = modelClass.name;
+
+ const smallModel = process.env.OPENAI_SMALL_MODEL ?? process.env.SMALL_MODEL ?? "gpt-4o-mini";
+ const largeModel = process.env.OPENAI_LARGE_MODEL ?? process.env.LARGE_MODEL ?? "gpt-4o";
+
+ // get the model name from the modelType
+ const model = modelType === ModelType.TEXT_SMALL ? smallModel : largeModel;
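+ // NOTE: any modelType other than TEXT_SMALL resolves to the large model here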
const { text: openaiResponse } = await aiGenerateText({
model: openai.languageModel(model),
@@ -42,6 +68,7 @@ export const openaiPlugin: Plugin = {
maxTokens: max_response_length,
frequencyPenalty: frequency_penalty,
presencePenalty: presence_penalty,
+ stopSequences,
});
return openaiResponse;
diff --git a/packages/plugin-sqlite/src/index.ts b/packages/plugin-sqlite/src/index.ts
index db7160a4717..d664865779a 100644
--- a/packages/plugin-sqlite/src/index.ts
+++ b/packages/plugin-sqlite/src/index.ts
@@ -1,26 +1,26 @@
-import path from "node:path";
import fs from "node:fs";
+import path from "node:path";
-export * from "./sqliteTables.ts";
export * from "./sqlite_vec.ts";
+export * from "./sqliteTables.ts";
-import {
- DatabaseAdapter,
- logger,
- type IDatabaseCacheAdapter,
-} from "@elizaos/core";
import type {
Account,
Actor,
- GoalStatus,
- Participant,
+ Adapter,
Goal,
+ GoalStatus,
+ IAgentRuntime,
Memory,
+ Participant,
+ Plugin,
Relationship,
UUID,
- Adapter,
- IAgentRuntime,
- Plugin,
+} from "@elizaos/core";
+import {
+ DatabaseAdapter,
+ logger,
+ type IDatabaseCacheAdapter,
} from "@elizaos/core";
import type { Database as BetterSqlite3Database } from "better-sqlite3";
import { v4 } from "uuid";
@@ -146,13 +146,7 @@ export class SqliteDatabaseAdapter
if (row === null) {
return null;
}
- return {
- ...row,
- details:
- typeof row.details === "string"
- ? JSON.parse(row.details)
- : row.details,
- };
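+ // NOTE: details is returned as stored; it is no longer JSON.parsed per row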
+ return row;
})
.filter((row): row is Actor => row !== null);
}