Skip to content

Commit

Permalink
Update generation.ts
Browse files Browse the repository at this point in the history
  • Loading branch information
wtfsayo committed Feb 7, 2025
1 parent c07210c commit 6e0d4c0
Showing 1 changed file with 42 additions and 31 deletions.
73 changes: 42 additions & 31 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,39 @@ function createModelClient(apiKey: string, baseURL: string, runtime: IAgentRunti
});
}

export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) {
// Add this utility function near other utility functions
function validateModelConfig(
provider: string,
config: {
apiKey?: string;
baseURL?: string;
modelProvider?: any;
modelClass?: ModelClass;
model?: string;
}
) {
const validations = [
{ value: config.apiKey, name: 'API key', for: provider },
{ value: config.baseURL, name: 'endpoint URL', for: provider },
{ value: config.modelProvider, name: 'model provider' },
{ value: config.modelProvider?.models, name: 'model configurations', in: 'provider' },
{ value: config.model, name: 'model name', for: `class ${config.modelClass}` }
];

for (const check of validations) {
if (!check.value) {
const message = check.for
? `No ${check.name} found for ${check.for}`
: check.in
? `${check.name} not found in ${check.in}`
: `${check.name} not initialized`;
elizaLogger.error(message);
throw new Error(message);
}
}
}

export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelClass = ModelClass.DEFAULT) {
elizaLogger.info(`Initializing model client with runtime: ${runtime.modelProvider}`);
const provider = runtime.getModelProvider()?.provider || runtime.modelProvider;
const baseURL = runtime.getModelProvider()?.endpoint;
Expand All @@ -116,38 +147,18 @@ export function initializeModelClient(runtime: IAgentRuntime, modelClass: ModelC
runtime.character.settings.secrets.PROVIDER_API_KEY ||
runtime.getSetting('PROVIDER_API_KEY');

if (!apiKey) {
elizaLogger.error(`No API key found for ${provider}`);
throw new Error(`No API key found for ${provider}`);
}

if (!baseURL) {
elizaLogger.error(`No endpoint URL found for ${provider}`);
throw new Error(`No endpoint URL found for ${provider}`);
}

const modelProvider = runtime.getModelProvider();
if (!modelProvider) {
elizaLogger.error('Model provider not initialized');
throw new Error('Model provider not initialized');
}

if (!modelProvider.models) {
elizaLogger.error('Model configurations not found in provider');
throw new Error('Model configurations not found in provider');
}

const modelConfig = modelProvider.models[modelClass];
if (!modelConfig) {
elizaLogger.error(`No model configuration found for class ${modelClass}`);
throw new Error(`No model configuration found for class ${modelClass}`);
}
const modelConfig = modelProvider?.models?.[modelClass];
const model = modelConfig?.name;

const model = modelConfig.name;
if (!model) {
elizaLogger.error(`Model name not specified for class ${modelClass}`);
throw new Error(`Model name not specified for class ${modelClass}`);
}
// Single validation call replaces multiple if-checks
validateModelConfig(provider, {
apiKey,
baseURL,
modelProvider,
modelClass,
model
});

const client = createModelClient(apiKey, baseURL, runtime);

Expand Down

0 comments on commit 6e0d4c0

Please sign in to comment.