-
Notifications
You must be signed in to change notification settings - Fork 12.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6855ce7
commit cc8ca9e
Showing
2 changed files
with
33 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,53 @@ | ||
// #vercel-disable-blocks | ||
import { fetch } from 'undici' | ||
// #vercel-end | ||
import { generatePayload, parseStreamResponse } from '@/utils/openAI' | ||
import { sendMessage } from '@/utils/openAI' | ||
import { verifySignature } from '@/utils/auth' | ||
import type { APIRoute } from 'astro' | ||
|
||
const apiKey = import.meta.env.GEMINI_API_KEY | ||
const baseUrl = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent' | ||
|
||
const sitePassword = import.meta.env.SITE_PASSWORD || '' | ||
const passList = sitePassword.split(',') || [] | ||
|
||
export const post: APIRoute = async(context) => { | ||
const body = await context.request.json() | ||
const { sign, time, messages, pass } = body | ||
|
||
if (!messages) { | ||
return new Response(JSON.stringify({ | ||
error: { | ||
message: 'No input text.', | ||
}, | ||
}), { status: 400 }) | ||
} | ||
|
||
if (sitePassword && !(sitePassword === pass || passList.includes(pass))) { | ||
return new Response(JSON.stringify({ | ||
error: { | ||
message: 'Invalid password.', | ||
}, | ||
}), { status: 401 }) | ||
} | ||
if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages?.[messages.length - 1]?.parts[0]?.text || '' }, sign)) { | ||
|
||
if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages[messages.length - 1].parts.map(part => part.text).join('') }, sign)) { | ||
return new Response(JSON.stringify({ | ||
error: { | ||
message: 'Invalid signature.', | ||
}, | ||
}), { status: 401 }) | ||
} | ||
|
||
console.log('Received messages:', messages); | ||
|
||
|
||
const initOptions = generatePayload(messages) | ||
|
||
const response = await fetch(`${baseUrl}?key=${apiKey}`, initOptions).catch((err: Error) => { | ||
console.error(err) | ||
try { | ||
const result = await sendMessage(messages) | ||
let text = '' | ||
for await (const chunk of result.stream) { | ||
const chunkText = chunk.text() | ||
text += chunkText | ||
} | ||
return new Response(JSON.stringify({ text }), { status: 200 }) | ||
} catch (error) { | ||
console.error(error) | ||
return new Response(JSON.stringify({ | ||
error: { | ||
code: err.name, | ||
message: err.message, | ||
code: error.name, | ||
message: error.message, | ||
}, | ||
}), { status: 500 }) | ||
}) as Response | ||
|
||
return parseStreamResponse(response) as Response | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,22 @@ | ||
// openAI.ts | ||
import { createParser } from 'eventsource-parser' | ||
import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser' | ||
import type { ChatMessage } from '@/types' | ||
import { GoogleGenerativeAI } from '@google/generative-ai' | ||
|
||
export const generatePayload = ( | ||
messages: ChatMessage[], | ||
): RequestInit => ({ | ||
headers: { | ||
'Content-Type': 'application/json', | ||
}, | ||
method: 'POST', | ||
body: JSON.stringify({ | ||
contents: messages.map(message => ({ | ||
role: message.role, | ||
parts: [{ text: message.content }] | ||
})), | ||
}), | ||
}) | ||
|
||
export const parseStreamResponse = (rawResponse: Response) => { | ||
const encoder = new TextEncoder() | ||
const decoder = new TextDecoder() | ||
if (!rawResponse.ok) { | ||
return new Response(rawResponse.body, { | ||
status: rawResponse.status, | ||
statusText: rawResponse.statusText, | ||
}) | ||
} | ||
const apiKey = process.env.GENERATIVE_LANGUAGE_API_KEY | ||
const genAI = new GoogleGenerativeAI(apiKey) | ||
|
||
const stream = new ReadableStream({ | ||
async start(controller) { | ||
const streamParser = (event: ParsedEvent | ReconnectInterval) => { | ||
if (event.type === 'event') { | ||
try { | ||
const json = JSON.parse(event.data) | ||
json.contents.forEach((content: { parts: { text: string }[] }) => { | ||
content.parts.forEach(part => { | ||
const text = part.text | ||
const queue = encoder.encode(text) | ||
controller.enqueue(queue) | ||
}) | ||
}) | ||
} catch (e) { | ||
controller.error(e) | ||
} | ||
} | ||
} | ||
export const sendMessage = async(messages: ChatMessage[]) => { | ||
const model = genAI.getGenerativeModel({ model: 'gemini-pro' }) | ||
|
||
const parser = createParser(streamParser) | ||
for await (const chunk of rawResponse.body as any) | ||
parser.feed(decoder.decode(chunk)) | ||
const chat = model.startChat({ | ||
history: messages.map(msg => ({ | ||
role: msg.role, | ||
parts: msg.parts.map(part => part.text), | ||
})), | ||
generationConfig: { | ||
maxOutputTokens: 4000, // or your desired token limit | ||
}, | ||
}) | ||
|
||
return new Response(stream) | ||
const lastMessage = messages[messages.length - 1] | ||
const result = await model.sendMessageStream(lastMessage.parts.map(part => part.text).join('')) | ||
return result | ||
} |