diff --git a/src/middlewares/analytics.ts b/src/middlewares/analytics.ts
index e997e2c..9b8591e 100644
--- a/src/middlewares/analytics.ts
+++ b/src/middlewares/analytics.ts
@@ -4,15 +4,20 @@ export function recordAnalytics(
     c: Context,
     endpoint: string,
     duration: number,
-    prompt_tokens: number,
-    completion_tokens: number
 ) {
+
+    const getModelName = c.get('getModelName');
+    const modelName = typeof getModelName === 'function' ? getModelName(c) : 'unknown';
+
+    const getTokenCount = c.get('getTokenCount');
+    const { input_tokens, output_tokens } = typeof getTokenCount === 'function' ? getTokenCount(c) : { input_tokens: 0, output_tokens: 0 };
+
+    // console.log(endpoint, c.req.path, modelName, input_tokens, output_tokens, c.get('malacca-cache-status') || 'miss', c.res.status);
+
     if (c.env.MALACCA) {
-        const getModelName = c.get('getModelName');
-        const modelName = typeof getModelName === 'function' ? getModelName(c) : 'unknown';
         c.env.MALACCA.writeDataPoint({
             'blobs': [endpoint, c.req.path, c.res.status, c.get('malacca-cache-status') || 'miss', modelName],
-            'doubles': [duration, prompt_tokens, completion_tokens],
+            'doubles': [duration, input_tokens, output_tokens],
             'indexes': [endpoint],
         });
     }
@@ -25,24 +30,10 @@ export const metricsMiddleware: MiddlewareHandler = async (c, next) => {
 
     c.executionCtx.waitUntil((async () => {
         await c.get('bufferPromise')
-        const buf = c.get('buffer')
         const endTime = Date.now();
         const duration = endTime - startTime;
         const endpoint = c.get('endpoint') || 'unknown';
-        let prompt_tokens = 0;
-        let completion_tokens = 0;
-        if (c.res.status === 200) {
-            if (c.res.headers.get('content-type') === 'application/json') {
-                const usage = JSON.parse(buf)['usage'];
-                if (usage) {
-                    prompt_tokens = usage['prompt_tokens'] | 0;
-                    completion_tokens = usage['completion_tokens'] | 0;
-                }
-            } else {
-                completion_tokens = buf.split('\n\n').length - 1;
-            }
-        }
-        recordAnalytics(c, endpoint, duration, prompt_tokens, completion_tokens);
+        recordAnalytics(c, endpoint, duration);
     })());
 };
 
diff --git a/src/middlewares/buffer.ts b/src/middlewares/buffer.ts
index 7c6b040..6af88b3 100644
--- a/src/middlewares/buffer.ts
+++ b/src/middlewares/buffer.ts
@@ -8,6 +8,9 @@ export const bufferMiddleware: MiddlewareHandler = async (c: Context, next: Next
     })
     c.set('bufferPromise', bufferPromise)
 
+    const reqBuffer: string = await c.req.text() || ''
+    c.set('reqBuffer', reqBuffer)
+
     await next()
 
     const originalResponse = c.res
diff --git a/src/middlewares/logging.ts b/src/middlewares/logging.ts
index bdfbaa3..1b39b81 100644
--- a/src/middlewares/logging.ts
+++ b/src/middlewares/logging.ts
@@ -5,7 +5,7 @@ export const loggingMiddleware = async (c: Context, next: Next) => {
 
     // Log request and response
     c.executionCtx.waitUntil((async () => {
-        const requestBody = await c.req.text().catch(() => ({}));
+        const requestBody = c.get('reqBuffer') || '';
         console.log('Request:', {
             body: requestBody,
         });
diff --git a/src/providers/azureOpenAI.ts b/src/providers/azureOpenAI.ts
index b99ba1a..9e0cd83 100644
--- a/src/providers/azureOpenAI.ts
+++ b/src/providers/azureOpenAI.ts
@@ -15,6 +15,7 @@ const azureOpenAIRoute = new Hono();
 const initMiddleware = async (c: Context, next: Next) => {
     c.set('endpoint', ProviderName);
     c.set('getModelName', getModelName);
+    c.set('getTokenCount', getTokenCount);
     await next();
 };
 
@@ -30,6 +31,7 @@ export const azureOpenAIProvider: AIProvider = {
     basePath: BasePath,
     route: azureOpenAIRoute,
     getModelName: getModelName,
+    getTokenCount: getTokenCount,
     handleRequest: async (c: Context) => {
         const resourceName = c.req.param('resource_name') || '';
         const deploymentName = c.req.param('deployment_name') || '';
@@ -78,4 +80,28 @@ function getModelName(c: Context): string {
         }
     }
     return "unknown"
-}
\ No newline at end of file
+}
+
+function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
+    const buf = c.get('buffer') || ""
+    if (c.res.status === 200) {
+        if (c.res.headers.get('content-type') === 'application/json') {
+            const usage = JSON.parse(buf)['usage'];
+            if (usage) {
+                const input_tokens = usage['prompt_tokens'] || 0;
+                const output_tokens = usage['completion_tokens'] || 0;
+                return { input_tokens, output_tokens }
+            }
+        } else {
+            // For streaming response, azure openai does not return usage in the response body, so we count the words and multiply by 4/3 to get the number of input tokens
+            const requestBody = c.get('reqBuffer') || '{}'
+            const messages = JSON.stringify(JSON.parse(requestBody).messages);
+            const input_tokens = Math.ceil(messages.split(/\s+/).length * 4 / 3);
+
+            // For streaming responses, we count the number of '\n\n' as the number of output tokens
+            const output_tokens = buf.split('\n\n').length - 1;
+            return { input_tokens: input_tokens, output_tokens: output_tokens }
+        }
+    }
+    return { input_tokens: 0, output_tokens: 0 }
+}
\ No newline at end of file
diff --git a/src/types.ts b/src/types.ts
index e8c6507..b3709f2 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -10,6 +10,7 @@ export interface AIProvider {
     name: string;
     handleRequest: (c: Context) => Promise<Response>;
     getModelName: (c: Context) => string;
+    getTokenCount: (c: Context) => {input_tokens: number, output_tokens: number};
     basePath: string;
     route: Hono;
 }
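Note on the new contract: recordAnalytics now falls back to zero token counts whenever a provider has not registered getTokenCount (the typeof guard), so existing providers keep reporting analytics while each provider supplies its own counting logic. The sketch below shows what a hypothetical additional provider conforming to the updated AIProvider interface might look like; the provider name, base path, and route wiring are illustrative assumptions, not part of this diff, and it only relies on the buffer and reqBuffer context values populated by bufferMiddleware.

// Sketch only: a hypothetical provider registering getTokenCount the same way
// azureOpenAIProvider does above. Names ('example', '/example/*') are illustrative.
import { Context, Hono, Next } from 'hono';
import { AIProvider } from '../types';

const exampleRoute = new Hono();

const initMiddleware = async (c: Context, next: Next) => {
    c.set('endpoint', 'example');
    c.set('getModelName', getModelName);
    c.set('getTokenCount', getTokenCount);
    await next();
};

exampleRoute.use('*', initMiddleware);

function getModelName(c: Context): string {
    // Placeholder: a real provider would derive the model from the request.
    return 'unknown';
}

function getTokenCount(c: Context): { input_tokens: number, output_tokens: number } {
    // Assumes bufferMiddleware has already buffered the full response body under 'buffer'.
    const buf = c.get('buffer') || '';
    try {
        const usage = JSON.parse(buf).usage;
        if (usage) {
            return {
                input_tokens: usage.prompt_tokens || 0,
                output_tokens: usage.completion_tokens || 0,
            };
        }
    } catch {
        // Non-JSON (e.g. streaming) body: fall back to zeros below.
    }
    return { input_tokens: 0, output_tokens: 0 };
}

export const exampleProvider: AIProvider = {
    name: 'example',
    basePath: '/example/*',
    route: exampleRoute,
    getModelName: getModelName,
    getTokenCount: getTokenCount,
    // Pass-through handler for the sketch; a real provider rewrites the upstream URL.
    handleRequest: async (c: Context) => fetch(c.req.raw),
};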