Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support anthropic provider #3364

Merged
merged 12 commits into from
Feb 7, 2025
Merged
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"dependencies": {
"@huggingface/transformers": "3.3.3",
"@ai-sdk/openai": "1.1.9",
"@ai-sdk/anthropic": "1.1.6",
"@tavily/core": "^0.0.2",
"@types/uuid": "10.0.0",
"ai": "4.1.16",
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export const CharacterSchema = z.object({
id: z.string().uuid().optional(),
name: z.string(),
system: z.string().optional(),
modelProvider: z.nativeEnum(ModelProviderName),
modelProvider: z.string().optional(),
modelEndpointOverride: z.string().optional(),
templates: z.record(z.string()).optional(),
bio: z.union([z.string(), z.array(z.string())]),
Expand Down
163 changes: 114 additions & 49 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// ================ IMPORTS ================
import { createOpenAI } from "@ai-sdk/openai";
import { createAnthropic } from "@ai-sdk/anthropic";
import {
experimental_generateImage as aiGenerateImage,
generateObject as aiGenerateObject,
Expand Down Expand Up @@ -85,8 +86,59 @@ async function withRetry<T>(
}
}

export function initializeModelClient(runtime: IAgentRuntime, modelClass:ModelClass = ModelClass.DEFAULT) {
function isAnthropicProvider(runtime: IAgentRuntime): boolean {
const provider = runtime.getModelProvider()?.provider;
return (
provider.toLowerCase().includes("anthropic") ||
provider.toLowerCase().includes("claude")
);
}

function createModelClient(apiKey: string, baseURL: string, runtime: IAgentRuntime) {
const createClient = isAnthropicProvider(runtime)
? createAnthropic
: createOpenAI;

return createClient({
apiKey,
baseURL,
fetch: runtime.fetch,
});
}

// Add this utility function near other utility functions
function validateModelConfig(
provider: string,
config: {
apiKey?: string;
baseURL?: string;
modelProvider?: any;
modelClass?: ModelClass;
model?: string;
}
) {
const validations = [
{ value: config.apiKey, name: 'API key', for: provider },
{ value: config.baseURL, name: 'endpoint URL', for: provider },
{ value: config.modelProvider, name: 'model provider' },
{ value: config.modelProvider?.models, name: 'model configurations', in: 'provider' },
{ value: config.model, name: 'model name', for: `class ${config.modelClass}` }
];
wtfsayo marked this conversation as resolved.
Show resolved Hide resolved

for (const check of validations) {
if (!check.value) {
const message = check.for
? `No ${check.name} found for ${check.for}`
: check.in
? `${check.name} not found in ${check.in}`
: `${check.name} not initialized`;
elizaLogger.error(message);
throw new Error(message);
}
}
}

export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) {
elizaLogger.info(`Initializing model client with runtime: ${runtime.modelProvider}`);
const provider = runtime.getModelProvider()?.provider || runtime.modelProvider;
const baseURL = runtime.getModelProvider()?.endpoint;
Expand All @@ -95,44 +147,20 @@ async function withRetry<T>(
runtime.character.settings.secrets.PROVIDER_API_KEY ||
runtime.getSetting('PROVIDER_API_KEY');

if (!apiKey) {
elizaLogger.error(`No API key found for ${provider}`);
throw new Error(`No API key found for ${provider}`);
}

if (!baseURL) {
elizaLogger.error(`No endpoint URL found for ${provider}`);
throw new Error(`No endpoint URL found for ${provider}`);
}

const modelProvider = runtime.getModelProvider();
if (!modelProvider) {
elizaLogger.error('Model provider not initialized');
throw new Error('Model provider not initialized');
}
const modelConfig = modelProvider?.models?.[modelClass];
const model = modelConfig?.name;

if (!modelProvider.models) {
elizaLogger.error('Model configurations not found in provider');
throw new Error('Model configurations not found in provider');
}

const modelConfig = modelProvider.models[modelClass];
if (!modelConfig) {
elizaLogger.error(`No model configuration found for class ${modelClass}`);
throw new Error(`No model configuration found for class ${modelClass}`);
}

const model = modelConfig.name;
if (!model) {
elizaLogger.error(`Model name not specified for class ${modelClass}`);
throw new Error(`Model name not specified for class ${modelClass}`);
}

const client = createOpenAI({
// Single validation call replaces multiple if-checks
validateModelConfig(provider, {
apiKey,
baseURL,
fetch: runtime.fetch,
modelProvider,
modelClass,
model
});

const client = createModelClient(apiKey, baseURL, runtime);

elizaLogger.info(`Initialized model client for ${provider} with baseURL ${baseURL} and model ${model}`);

Expand Down Expand Up @@ -325,6 +353,39 @@ export async function generateTrueOrFalse({
return result === 'true';
}

function getModelConfig(
runtime: IAgentRuntime,
client: any,
model: string,
mode: 'auto' | 'json' | 'tool',
options: {
context: string;
output?: 'object' | 'array' | 'enum' | 'no-schema';
schema?: ZodSchema;
schemaName?: string;
schemaDescription?: string;
enumValues?: string[];
stopSequences?: string[];
}
): any {
if (isAnthropicProvider(runtime) && mode === "json") {
elizaLogger.warn("Anthropic does not support JSON mode. Switching to 'auto'.");
mode = "auto";
}

const config = {
model: client.languageModel(model),
prompt: options.context.toString(),
system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
output: options.output as never,
mode: mode as never,
...(options.schema ? { schema: options.schema, schemaName: options.schemaName, schemaDescription: options.schemaDescription } : {}),
...(options.enumValues ? { enum: options.enumValues } : {}),
};

return options.stopSequences ? { ...config, stopSequences: options.stopSequences } : config;
}

// ================ OBJECT GENERATION FUNCTIONS ================
export const generateObject = async ({
runtime,
Expand Down Expand Up @@ -352,21 +413,17 @@ export const generateObject = async ({
throw new Error('Enum values are required when output type is enum');
}

// Create the base configuration object
const config = {
model: client.languageModel(model),
prompt: context.toString(),
system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
output: output as never,
mode: mode as never,
...(schema ? { schema, schemaName, schemaDescription } : {}),
...(enumValues ? { enum: enumValues } : {})
};

// Only add stopSequences if it's defined
const finalConfig = stopSequences ? { ...config, stopSequences } : config;
const config = getModelConfig(runtime, client, model, mode, {
context,
output,
schema,
schemaName,
schemaDescription,
enumValues,
stopSequences,
});

const {object} = await aiGenerateObject(finalConfig);
const {object} = await aiGenerateObject(config);

elizaLogger.debug(`Received Object response from ${model} model.`);
return schema ? schema.parse(object) : object;
Expand Down Expand Up @@ -506,14 +563,22 @@ export const generateImage = async (
elizaLogger.warn("No model settings found for the image model provider.");
return { success: false, error: "No model settings available" };
}

if (isAnthropicProvider(runtime)) {
return {
success: false,
error: "Unsupported provider: Anthropic does not support image generation.",
};
}

const { model, client } = initializeModelClient(runtime, ModelClass.IMAGE);
elizaLogger.info("Generating image with options:", {
imageModelProvider: model,
});

await withRetry(async () => {
const result = await aiGenerateImage({
model: client.imageModel(model),
model: (client as any).imageModel(model),
prompt: data.prompt,
size: `${data.width}x${data.height}`,
n: data.count,
Expand Down
Loading