
Merge pull request #3364 from elizaOS/tcm/support-anthropic
support anthropic provider
tcm390 authored Feb 7, 2025
2 parents ef876f8 + 6e0d4c0 commit 2d875c3
Showing 3 changed files with 116 additions and 50 deletions.
1 change: 1 addition & 0 deletions packages/core/package.json
@@ -61,6 +61,7 @@
"dependencies": {
"@huggingface/transformers": "3.3.3",
"@ai-sdk/openai": "1.1.9",
"@ai-sdk/anthropic": "1.1.6",
"@tavily/core": "^0.0.2",
"@types/uuid": "10.0.0",
"ai": "4.1.16",
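
For context, @ai-sdk/anthropic is the Vercel AI SDK provider package that generation.ts below wires in through createAnthropic. A minimal standalone sketch of how the new dependency is consumed; the API key variable, endpoint, and model id here are illustrative placeholders, not values taken from this commit:

import { createAnthropic } from "@ai-sdk/anthropic";
import { generateText } from "ai";

// Illustrative values only; in eliza the real client is built inside initializeModelClient below.
const anthropic = createAnthropic({
    apiKey: process.env.ANTHROPIC_API_KEY,   // assumed env var for this sketch
    baseURL: "https://api.anthropic.com/v1", // assumed Anthropic endpoint
});

const { text } = await generateText({
    model: anthropic.languageModel("claude-3-5-sonnet-20241022"), // example model id
    prompt: "Say hello in one sentence.",
});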
2 changes: 1 addition & 1 deletion packages/core/src/environment.ts
@@ -75,7 +75,7 @@ export const CharacterSchema = z.object({
id: z.string().uuid().optional(),
name: z.string(),
system: z.string().optional(),
- modelProvider: z.nativeEnum(ModelProviderName),
+ modelProvider: z.string().optional(),
modelEndpointOverride: z.string().optional(),
templates: z.record(z.string()).optional(),
bio: z.union([z.string(), z.array(z.string())]),
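
The change above relaxes modelProvider from the ModelProviderName enum to an optional free-form string, so a character file can name a provider such as "anthropic" without the enum having to list it. A small sketch of the before/after validation behavior; the enum members shown are illustrative stand-ins, not the full eliza enum:

import { z } from "zod";

// Illustrative stand-in for the real ModelProviderName enum.
enum ModelProviderName {
    OPENAI = "openai",
    LLAMALOCAL = "llamalocal",
}

// Before: only enum members validate.
const before = z.object({ modelProvider: z.nativeEnum(ModelProviderName) });
console.log(before.safeParse({ modelProvider: "anthropic" }).success); // false

// After: any string validates, and the field may be omitted entirely.
const after = z.object({ modelProvider: z.string().optional() });
console.log(after.safeParse({ modelProvider: "anthropic" }).success);  // true
console.log(after.safeParse({}).success);                              // true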
163 changes: 114 additions & 49 deletions packages/core/src/generation.ts
@@ -1,5 +1,6 @@
// ================ IMPORTS ================
import { createOpenAI } from "@ai-sdk/openai";
import { createAnthropic } from "@ai-sdk/anthropic";
import {
experimental_generateImage as aiGenerateImage,
generateObject as aiGenerateObject,
@@ -85,8 +86,59 @@ async function withRetry<T>(
}
}

- export function initializeModelClient(runtime: IAgentRuntime, modelClass:ModelClass = ModelClass.DEFAULT) {
function isAnthropicProvider(runtime: IAgentRuntime): boolean {
const provider = runtime.getModelProvider()?.provider;
return (
provider.toLowerCase().includes("anthropic") ||
provider.toLowerCase().includes("claude")
);
}

function createModelClient(apiKey: string, baseURL: string, runtime: IAgentRuntime) {
const createClient = isAnthropicProvider(runtime)
? createAnthropic
: createOpenAI;

return createClient({
apiKey,
baseURL,
fetch: runtime.fetch,
});
}

// Add this utility function near other utility functions
function validateModelConfig(
provider: string,
config: {
apiKey?: string;
baseURL?: string;
modelProvider?: any;
modelClass?: ModelClass;
model?: string;
}
) {
const validations = [
{ value: config.apiKey, name: 'API key', for: provider },
{ value: config.baseURL, name: 'endpoint URL', for: provider },
{ value: config.modelProvider, name: 'model provider' },
{ value: config.modelProvider?.models, name: 'model configurations', in: 'provider' },
{ value: config.model, name: 'model name', for: `class ${config.modelClass}` }
];

for (const check of validations) {
if (!check.value) {
const message = check.for
? `No ${check.name} found for ${check.for}`
: check.in
? `${check.name} not found in ${check.in}`
: `${check.name} not initialized`;
elizaLogger.error(message);
throw new Error(message);
}
}
}

+ export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) {
elizaLogger.info(`Initializing model client with runtime: ${runtime.modelProvider}`);
const provider = runtime.getModelProvider()?.provider || runtime.modelProvider;
const baseURL = runtime.getModelProvider()?.endpoint;
@@ -95,44 +147,20 @@ async function withRetry<T>(
runtime.character.settings.secrets.PROVIDER_API_KEY ||
runtime.getSetting('PROVIDER_API_KEY');

- if (!apiKey) {
- elizaLogger.error(`No API key found for ${provider}`);
- throw new Error(`No API key found for ${provider}`);
- }

- if (!baseURL) {
- elizaLogger.error(`No endpoint URL found for ${provider}`);
- throw new Error(`No endpoint URL found for ${provider}`);
- }

const modelProvider = runtime.getModelProvider();
- if (!modelProvider) {
- elizaLogger.error('Model provider not initialized');
- throw new Error('Model provider not initialized');
- }
+ const modelConfig = modelProvider?.models?.[modelClass];
+ const model = modelConfig?.name;

- if (!modelProvider.models) {
- elizaLogger.error('Model configurations not found in provider');
- throw new Error('Model configurations not found in provider');
- }

- const modelConfig = modelProvider.models[modelClass];
- if (!modelConfig) {
- elizaLogger.error(`No model configuration found for class ${modelClass}`);
- throw new Error(`No model configuration found for class ${modelClass}`);
- }

- const model = modelConfig.name;
- if (!model) {
- elizaLogger.error(`Model name not specified for class ${modelClass}`);
- throw new Error(`Model name not specified for class ${modelClass}`);
- }

- const client = createOpenAI({
+ // Single validation call replaces multiple if-checks
+ validateModelConfig(provider, {
apiKey,
baseURL,
- fetch: runtime.fetch,
+ modelProvider,
+ modelClass,
+ model
});

+ const client = createModelClient(apiKey, baseURL, runtime);

elizaLogger.info(`Initialized model client for ${provider} with baseURL ${baseURL} and model ${model}`);

@@ -325,6 +353,39 @@ export async function generateTrueOrFalse({
return result === 'true';
}

function getModelConfig(
runtime: IAgentRuntime,
client: any,
model: string,
mode: 'auto' | 'json' | 'tool',
options: {
context: string;
output?: 'object' | 'array' | 'enum' | 'no-schema';
schema?: ZodSchema;
schemaName?: string;
schemaDescription?: string;
enumValues?: string[];
stopSequences?: string[];
}
): any {
if (isAnthropicProvider(runtime) && mode === "json") {
elizaLogger.warn("Anthropic does not support JSON mode. Switching to 'auto'.");
mode = "auto";
}

const config = {
model: client.languageModel(model),
prompt: options.context.toString(),
system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
output: options.output as never,
mode: mode as never,
...(options.schema ? { schema: options.schema, schemaName: options.schemaName, schemaDescription: options.schemaDescription } : {}),
...(options.enumValues ? { enum: options.enumValues } : {}),
};

return options.stopSequences ? { ...config, stopSequences: options.stopSequences } : config;
}

// ================ OBJECT GENERATION FUNCTIONS ================
export const generateObject = async ({
runtime,
@@ -352,21 +413,17 @@ export const generateObject = async ({
throw new Error('Enum values are required when output type is enum');
}

- // Create the base configuration object
- const config = {
- model: client.languageModel(model),
- prompt: context.toString(),
- system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
- output: output as never,
- mode: mode as never,
- ...(schema ? { schema, schemaName, schemaDescription } : {}),
- ...(enumValues ? { enum: enumValues } : {})
- };

- // Only add stopSequences if it's defined
- const finalConfig = stopSequences ? { ...config, stopSequences } : config;
+ const config = getModelConfig(runtime, client, model, mode, {
+ context,
+ output,
+ schema,
+ schemaName,
+ schemaDescription,
+ enumValues,
+ stopSequences,
+ });

- const {object} = await aiGenerateObject(finalConfig);
+ const {object} = await aiGenerateObject(config);

elizaLogger.debug(`Received Object response from ${model} model.`);
return schema ? schema.parse(object) : object;
@@ -506,14 +563,22 @@ export const generateImage = async (
elizaLogger.warn("No model settings found for the image model provider.");
return { success: false, error: "No model settings available" };
}

if (isAnthropicProvider(runtime)) {
return {
success: false,
error: "Unsupported provider: Anthropic does not support image generation.",
};
}

const { model, client } = initializeModelClient(runtime, ModelClass.IMAGE);
elizaLogger.info("Generating image with options:", {
imageModelProvider: model,
});

await withRetry(async () => {
const result = await aiGenerateImage({
- model: client.imageModel(model),
+ model: (client as any).imageModel(model),
prompt: data.prompt,
size: `${data.width}x${data.height}`,
n: data.count,
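
Taken together, the generation.ts changes route any provider whose id contains "anthropic" or "claude" through createAnthropic, downgrade the unsupported JSON mode to 'auto', and reject image generation for Anthropic. A hedged sketch of the character and provider configuration initializeModelClient expects to find; the field names follow the diff, while the endpoint, secret, ModelClass keys, and model ids are placeholders:

// Hypothetical character configuration consumed by initializeModelClient above.
const character = {
    name: "Eliza",
    modelProvider: "anthropic",             // matched by isAnthropicProvider()
    settings: {
        secrets: {
            PROVIDER_API_KEY: "sk-ant-...", // or supplied via runtime.getSetting('PROVIDER_API_KEY')
        },
    },
};

// Rough shape runtime.getModelProvider() is expected to return.
const modelProviderConfig = {
    provider: "anthropic",
    endpoint: "https://api.anthropic.com/v1", // assumed endpoint
    models: {
        // keyed by ModelClass; model ids are examples only
        SMALL: { name: "claude-3-haiku-20240307" },
        LARGE: { name: "claude-3-5-sonnet-20241022" },
    },
};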
