diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index 3243fe9ca39..a383c24a86e 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -106,8 +106,39 @@ function createModelClient(apiKey: string, baseURL: string, runtime: IAgentRunti }); } -export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) { +// Add this utility function near other utility functions +function validateModelConfig( + provider: string, + config: { + apiKey?: string; + baseURL?: string; + modelProvider?: any; + modelClass?: ModelClass; + model?: string; + } +) { + const validations = [ + { value: config.apiKey, name: 'API key', for: provider }, + { value: config.baseURL, name: 'endpoint URL', for: provider }, + { value: config.modelProvider, name: 'model provider' }, + { value: config.modelProvider?.models, name: 'model configurations', in: 'provider' }, + { value: config.model, name: 'model name', for: `class ${config.modelClass}` } + ]; + + for (const check of validations) { + if (!check.value) { + const message = check.for + ? `No ${check.name} found for ${check.for}` + : check.in + ? `${check.name} not found in ${check.in}` + : `${check.name} not initialized`; + elizaLogger.error(message); + throw new Error(message); + } + } +} +export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) { elizaLogger.info(`Initializing model client with runtime: ${runtime.modelProvider}`); const provider = runtime.getModelProvider()?.provider || runtime.modelProvider; const baseURL = runtime.getModelProvider()?.endpoint; @@ -116,38 +147,18 @@ export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelC runtime.character.settings.secrets.PROVIDER_API_KEY || runtime.getSetting('PROVIDER_API_KEY'); - if (!apiKey) { - elizaLogger.error(`No API key found for ${provider}`); - throw new Error(`No API key found for ${provider}`); - } - - if (!baseURL) { - elizaLogger.error(`No endpoint URL found for ${provider}`); - throw new Error(`No endpoint URL found for ${provider}`); - } - const modelProvider = runtime.getModelProvider(); - if (!modelProvider) { - elizaLogger.error('Model provider not initialized'); - throw new Error('Model provider not initialized'); - } - - if (!modelProvider.models) { - elizaLogger.error('Model configurations not found in provider'); - throw new Error('Model configurations not found in provider'); - } - - const modelConfig = modelProvider.models[modelClass]; - if (!modelConfig) { - elizaLogger.error(`No model configuration found for class ${modelClass}`); - throw new Error(`No model configuration found for class ${modelClass}`); - } + const modelConfig = modelProvider?.models?.[modelClass]; + const model = modelConfig?.name; - const model = modelConfig.name; - if (!model) { - elizaLogger.error(`Model name not specified for class ${modelClass}`); - throw new Error(`Model name not specified for class ${modelClass}`); - } + // Single validation call replaces multiple if-checks + validateModelConfig(provider, { + apiKey, + baseURL, + modelProvider, + modelClass, + model + }); const client = createModelClient(apiKey, baseURL, runtime);