Skip to content

Commit

Permalink
use google sdk
Browse files — browse the repository at this point in the history
  • Loading branch information
babaohuang committed Dec 14, 2023
1 parent 6855ce7 commit cc8ca9e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 70 deletions.
38 changes: 18 additions & 20 deletions src/pages/api/generate.ts
Original file line number Diff line number Diff line change
@@ -1,55 +1,53 @@
// #vercel-disable-blocks
import { fetch } from 'undici'
// #vercel-end
import { sendMessage } from '@/utils/openAI'
import { verifySignature } from '@/utils/auth'
import type { APIRoute } from 'astro'

// Optional site-wide password gate. Several passwords may be supplied as a
// comma-separated list in SITE_PASSWORD; an empty value disables the gate.
const sitePassword = import.meta.env.SITE_PASSWORD || ''
// String.prototype.split never returns a falsy value, so no `|| []` fallback
// is needed here.
const passList = sitePassword.split(',')

export const post: APIRoute = async(context) => {
const body = await context.request.json()
const { sign, time, messages, pass } = body

if (!messages) {
return new Response(JSON.stringify({
error: {
message: 'No input text.',
},
}), { status: 400 })
}

if (sitePassword && !(sitePassword === pass || passList.includes(pass))) {
return new Response(JSON.stringify({
error: {
message: 'Invalid password.',
},
}), { status: 401 })
}
if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages?.[messages.length - 1]?.parts[0]?.text || '' }, sign)) {

if (import.meta.env.PROD && !await verifySignature({ t: time, m: messages[messages.length - 1].parts.map(part => part.text).join('') }, sign)) {
return new Response(JSON.stringify({
error: {
message: 'Invalid signature.',
},
}), { status: 401 })
}

console.log('Received messages:', messages);


const initOptions = generatePayload(messages)

const response = await fetch(`${baseUrl}?key=${apiKey}`, initOptions).catch((err: Error) => {
console.error(err)
try {
const result = await sendMessage(messages)
let text = ''
for await (const chunk of result.stream) {
const chunkText = chunk.text()
text += chunkText
}
return new Response(JSON.stringify({ text }), { status: 200 })
} catch (error) {
console.error(error)
return new Response(JSON.stringify({
error: {
code: err.name,
message: err.message,
code: error.name,
message: error.message,
},
}), { status: 500 })
}) as Response

return parseStreamResponse(response) as Response
}
}
65 changes: 15 additions & 50 deletions src/utils/openAI.ts
Original file line number Diff line number Diff line change
@@ -1,57 +1,22 @@
// openAI.ts
import { createParser } from 'eventsource-parser'
import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser'
import type { ChatMessage } from '@/types'
import { GoogleGenerativeAI } from '@google/generative-ai'

export const generatePayload = (
messages: ChatMessage[],
): RequestInit => ({
headers: {
'Content-Type': 'application/json',
},
method: 'POST',
body: JSON.stringify({
contents: messages.map(message => ({
role: message.role,
parts: [{ text: message.content }]
})),
}),
})

export const parseStreamResponse = (rawResponse: Response) => {
const encoder = new TextEncoder()
const decoder = new TextDecoder()
if (!rawResponse.ok) {
return new Response(rawResponse.body, {
status: rawResponse.status,
statusText: rawResponse.statusText,
})
}
const apiKey = process.env.GENERATIVE_LANGUAGE_API_KEY
const genAI = new GoogleGenerativeAI(apiKey)

const stream = new ReadableStream({
async start(controller) {
const streamParser = (event: ParsedEvent | ReconnectInterval) => {
if (event.type === 'event') {
try {
const json = JSON.parse(event.data)
json.contents.forEach((content: { parts: { text: string }[] }) => {
content.parts.forEach(part => {
const text = part.text
const queue = encoder.encode(text)
controller.enqueue(queue)
})
})
} catch (e) {
controller.error(e)
}
}
}
/**
 * Sends the conversation to the Gemini Pro model and returns the SDK's
 * streaming result (caller iterates `result.stream`).
 *
 * @param messages full transcript; the final entry is the new user prompt.
 * @returns the GenerateContentStreamResult from the SDK.
 */
export const sendMessage = async(messages: ChatMessage[]) => {
  const model = genAI.getGenerativeModel({ model: 'gemini-pro' })

  // History is every message EXCEPT the last one — the last message is the
  // new prompt sent below. Including it in history as well would make the
  // model see the prompt twice.
  const chat = model.startChat({
    history: messages.slice(0, -1).map(msg => ({
      role: msg.role,
      parts: msg.parts.map(part => part.text),
    })),
    generationConfig: {
      maxOutputTokens: 4000, // or your desired token limit
    },
  })

  const lastMessage = messages[messages.length - 1]
  // sendMessageStream is a ChatSession method, not a GenerativeModel method:
  // calling it on `model` fails and would in any case discard the history
  // and generationConfig configured on `chat` above.
  const result = await chat.sendMessageStream(lastMessage.parts.map(part => part.text).join(''))
  return result
}

0 comments on commit cc8ca9e

Please sign in to comment.