diff --git a/packages/core/package.json b/packages/core/package.json index 63e664f9bad..74a32a01ec3 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -61,6 +61,7 @@ "dependencies": { "@huggingface/transformers": "3.3.3", "@ai-sdk/openai": "1.1.9", + "@ai-sdk/anthropic": "1.1.6", "@tavily/core": "^0.0.2", "@types/uuid": "10.0.0", "ai": "4.1.16", diff --git a/packages/core/src/environment.ts b/packages/core/src/environment.ts index e2c5e61ccba..835d7f38e6b 100644 --- a/packages/core/src/environment.ts +++ b/packages/core/src/environment.ts @@ -75,7 +75,7 @@ export const CharacterSchema = z.object({ id: z.string().uuid().optional(), name: z.string(), system: z.string().optional(), - modelProvider: z.nativeEnum(ModelProviderName), + modelProvider: z.string().optional(), modelEndpointOverride: z.string().optional(), templates: z.record(z.string()).optional(), bio: z.union([z.string(), z.array(z.string())]), diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index 18186ec27ac..a383c24a86e 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -1,5 +1,6 @@ // ================ IMPORTS ================ import { createOpenAI } from "@ai-sdk/openai"; +import { createAnthropic } from "@ai-sdk/anthropic"; import { experimental_generateImage as aiGenerateImage, generateObject as aiGenerateObject, @@ -85,8 +86,59 @@ async function withRetry( } } - export function initializeModelClient(runtime: IAgentRuntime, modelClass:ModelClass = ModelClass.DEFAULT) { +function isAnthropicProvider(runtime: IAgentRuntime): boolean { + const provider = runtime.getModelProvider()?.provider; + return ( + provider.toLowerCase().includes("anthropic") || + provider.toLowerCase().includes("claude") + ); +} + +function createModelClient(apiKey: string, baseURL: string, runtime: IAgentRuntime) { + const createClient = isAnthropicProvider(runtime) + ? createAnthropic + : createOpenAI; + + return createClient({ + apiKey, + baseURL, + fetch: runtime.fetch, + }); +} +// Add this utility function near other utility functions +function validateModelConfig( + provider: string, + config: { + apiKey?: string; + baseURL?: string; + modelProvider?: any; + modelClass?: ModelClass; + model?: string; + } +) { + const validations = [ + { value: config.apiKey, name: 'API key', for: provider }, + { value: config.baseURL, name: 'endpoint URL', for: provider }, + { value: config.modelProvider, name: 'model provider' }, + { value: config.modelProvider?.models, name: 'model configurations', in: 'provider' }, + { value: config.model, name: 'model name', for: `class ${config.modelClass}` } + ]; + + for (const check of validations) { + if (!check.value) { + const message = check.for + ? `No ${check.name} found for ${check.for}` + : check.in + ? `${check.name} not found in ${check.in}` + : `${check.name} not initialized`; + elizaLogger.error(message); + throw new Error(message); + } + } +} + +export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) { elizaLogger.info(`Initializing model client with runtime: ${runtime.modelProvider}`); const provider = runtime.getModelProvider()?.provider || runtime.modelProvider; const baseURL = runtime.getModelProvider()?.endpoint; @@ -95,44 +147,20 @@ async function withRetry( runtime.character.settings.secrets.PROVIDER_API_KEY || runtime.getSetting('PROVIDER_API_KEY'); - if (!apiKey) { - elizaLogger.error(`No API key found for ${provider}`); - throw new Error(`No API key found for ${provider}`); - } - - if (!baseURL) { - elizaLogger.error(`No endpoint URL found for ${provider}`); - throw new Error(`No endpoint URL found for ${provider}`); - } - const modelProvider = runtime.getModelProvider(); - if (!modelProvider) { - elizaLogger.error('Model provider not initialized'); - throw new Error('Model provider not initialized'); - } + const modelConfig = modelProvider?.models?.[modelClass]; + const model = modelConfig?.name; - if (!modelProvider.models) { - elizaLogger.error('Model configurations not found in provider'); - throw new Error('Model configurations not found in provider'); - } - - const modelConfig = modelProvider.models[modelClass]; - if (!modelConfig) { - elizaLogger.error(`No model configuration found for class ${modelClass}`); - throw new Error(`No model configuration found for class ${modelClass}`); - } - - const model = modelConfig.name; - if (!model) { - elizaLogger.error(`Model name not specified for class ${modelClass}`); - throw new Error(`Model name not specified for class ${modelClass}`); - } - - const client = createOpenAI({ + // Single validation call replaces multiple if-checks + validateModelConfig(provider, { apiKey, baseURL, - fetch: runtime.fetch, + modelProvider, + modelClass, + model }); + + const client = createModelClient(apiKey, baseURL, runtime); elizaLogger.info(`Initialized model client for ${provider} with baseURL ${baseURL} and model ${model}`); @@ -325,6 +353,39 @@ export async function generateTrueOrFalse({ return result === 'true'; } +function getModelConfig( + runtime: IAgentRuntime, + client: any, + model: string, + mode: 'auto' | 'json' | 'tool', + options: { + context: string; + output?: 'object' | 'array' | 'enum' | 'no-schema'; + schema?: ZodSchema; + schemaName?: string; + schemaDescription?: string; + enumValues?: string[]; + stopSequences?: string[]; + } +): any { + if (isAnthropicProvider(runtime) && mode === "json") { + elizaLogger.warn("Anthropic does not support JSON mode. Switching to 'auto'."); + mode = "auto"; + } + + const config = { + model: client.languageModel(model), + prompt: options.context.toString(), + system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined, + output: options.output as never, + mode: mode as never, + ...(options.schema ? { schema: options.schema, schemaName: options.schemaName, schemaDescription: options.schemaDescription } : {}), + ...(options.enumValues ? { enum: options.enumValues } : {}), + }; + + return options.stopSequences ? { ...config, stopSequences: options.stopSequences } : config; +} + // ================ OBJECT GENERATION FUNCTIONS ================ export const generateObject = async ({ runtime, @@ -352,21 +413,17 @@ export const generateObject = async ({ throw new Error('Enum values are required when output type is enum'); } - // Create the base configuration object - const config = { - model: client.languageModel(model), - prompt: context.toString(), - system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined, - output: output as never, - mode: mode as never, - ...(schema ? { schema, schemaName, schemaDescription } : {}), - ...(enumValues ? { enum: enumValues } : {}) - }; - - // Only add stopSequences if it's defined - const finalConfig = stopSequences ? { ...config, stopSequences } : config; + const config = getModelConfig(runtime, client, model, mode, { + context, + output, + schema, + schemaName, + schemaDescription, + enumValues, + stopSequences, + }); - const {object} = await aiGenerateObject(finalConfig); + const {object} = await aiGenerateObject(config); elizaLogger.debug(`Received Object response from ${model} model.`); return schema ? schema.parse(object) : object; @@ -506,6 +563,14 @@ export const generateImage = async ( elizaLogger.warn("No model settings found for the image model provider."); return { success: false, error: "No model settings available" }; } + + if (isAnthropicProvider(runtime)) { + return { + success: false, + error: "Unsupported provider: Anthropic does not support image generation.", + }; + } + const { model, client } = initializeModelClient(runtime, ModelClass.IMAGE); elizaLogger.info("Generating image with options:", { imageModelProvider: model, @@ -513,7 +578,7 @@ export const generateImage = async ( await withRetry(async () => { const result = await aiGenerateImage({ - model: client.imageModel(model), + model: (client as any).imageModel(model), prompt: data.prompt, size: `${data.width}x${data.height}`, n: data.count,