diff --git a/src/pages/api/generate.ts b/src/pages/api/generate.ts
index 27d103d0..892f7427 100644
--- a/src/pages/api/generate.ts
+++ b/src/pages/api/generate.ts
@@ -1,19 +1,14 @@
-// #vercel-disable-blocks
-import { fetch } from 'undici'
-// #vercel-end
-import { generatePayload, parseStreamResponse } from '@/utils/openAI'
+import { sendMessage } from '@/utils/openAI'
 import { verifySignature } from '@/utils/auth'
 import type { APIRoute } from 'astro'
 
-const apiKey = import.meta.env.GEMINI_API_KEY
-const baseUrl = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent'
-
 const sitePassword = import.meta.env.SITE_PASSWORD || ''
 const passList = sitePassword.split(',') || []
 
 export const post: APIRoute = async(context) => {
   const body = await context.request.json()
   const { sign, time, messages, pass } = body
+
   if (!messages) {
     return new Response(JSON.stringify({
       error: {
@@ -21,6 +16,7 @@ export const post: APIRoute = async(context) => {
       },
     }), { status: 400 })
   }
+
   if (sitePassword && !(sitePassword === pass || passList.includes(pass))) {
     return new Response(JSON.stringify({
       error: {
@@ -28,28 +24,33 @@ export const post: APIRoute = async(context) => {
       },
     }), { status: 401 })
   }
-  if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages?.[messages.length - 1]?.parts[0]?.text || '' }, sign)) {
+
+  if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages[messages.length - 1].parts.map(part => part.text).join('') }, sign)) {
     return new Response(JSON.stringify({
       error: {
         message: 'Invalid signature.',
       },
     }), { status: 401 })
   }
-
-  console.log('Received messages:', messages);
-
-  const initOptions = generatePayload(messages)
-
-  const response = await fetch(`${baseUrl}?key=${apiKey}`, initOptions).catch((err: Error) => {
+
+  try {
+    const result = await sendMessage(messages)
+    let text = ''
+    for await (const chunk of result.stream) {
+      const chunkText = chunk.text()
+      text += chunkText
+    }
+    return new Response(JSON.stringify({ text }), { status: 200 })
+  } catch (error) {
+    // The catch binding is `unknown` under strict TS; narrow before reading .name/.message.
+    const err = error instanceof Error ? error : new Error(String(error))
     console.error(err)
     return new Response(JSON.stringify({
       error: {
         code: err.name,
         message: err.message,
       },
     }), { status: 500 })
-  }) as Response
-
-  return parseStreamResponse(response) as Response
+  }
 }
 
diff --git a/src/utils/openAI.ts b/src/utils/openAI.ts
index cd51e194..1ae92773 100644
--- a/src/utils/openAI.ts
+++ b/src/utils/openAI.ts
@@ -1,57 +1,25 @@
-// openAI.ts
-import { createParser } from 'eventsource-parser'
-import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser'
-import type { ChatMessage } from '@/types'
+import { GoogleGenerativeAI } from '@google/generative-ai'
+import type { ChatMessage } from '@/types'
 
-export const generatePayload = (
-  messages: ChatMessage[],
-): RequestInit => ({
-  headers: {
-    'Content-Type': 'application/json',
-  },
-  method: 'POST',
-  body: JSON.stringify({
-    contents: messages.map(message => ({
-      role: message.role,
-      parts: [{ text: message.content }]
-    })),
-  }),
-})
-
-export const parseStreamResponse = (rawResponse: Response) => {
-  const encoder = new TextEncoder()
-  const decoder = new TextDecoder()
-  if (!rawResponse.ok) {
-    return new Response(rawResponse.body, {
-      status: rawResponse.status,
-      statusText: rawResponse.statusText,
-    })
-  }
+const apiKey = import.meta.env.GENERATIVE_LANGUAGE_API_KEY
+const genAI = new GoogleGenerativeAI(apiKey)
 
-  const stream = new ReadableStream({
-    async start(controller) {
-      const streamParser = (event: ParsedEvent | ReconnectInterval) => {
-        if (event.type === 'event') {
-          try {
-            const json = JSON.parse(event.data)
-            json.contents.forEach((content: { parts: { text: string }[] }) => {
-              content.parts.forEach(part => {
-                const text = part.text
-                const queue = encoder.encode(text)
-                controller.enqueue(queue)
-              })
-            })
-          } catch (e) {
-            controller.error(e)
-          }
-        }
-      }
+export const sendMessage = async(messages: ChatMessage[]) => {
+  const model = genAI.getGenerativeModel({ model: 'gemini-pro' })
 
-      const parser = createParser(streamParser)
-      for await (const chunk of rawResponse.body as any)
-        parser.feed(decoder.decode(chunk))
+  const chat = model.startChat({
+    // All but the last message seed the chat history; the last one is sent below.
+    history: messages.slice(0, -1).map(msg => ({
+      role: msg.role,
+      parts: msg.parts.map(part => ({ text: part.text })),
+    })),
+    generationConfig: {
+      maxOutputTokens: 4000, // or your desired token limit
     },
   })
-  return new Response(stream)
+
+  const lastMessage = messages[messages.length - 1]
+  const result = await chat.sendMessageStream(lastMessage.parts.map(part => part.text).join(''))
+  return result
 }
 