From 4b6c989f2ef0060a4022f1c5a913e0b4c9b16c89 Mon Sep 17 00:00:00 2001 From: babaohuang Date: Thu, 14 Dec 2023 15:32:18 +0800 Subject: [PATCH] Gemini Modify to support Gemini --- .env.example | 2 +- src/components/Generator.tsx | 215 ++++++++++++++++++----------------- src/pages/api/generate.ts | 24 ++-- src/utils/openAI.ts | 44 +++---- 4 files changed, 134 insertions(+), 151 deletions(-) diff --git a/.env.example b/.env.example index b98ff44c..c469dc69 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ # Your API Key for OpenAI -OPENAI_API_KEY= +GEMINI_API_KEY= # Provide proxy for OpenAI API. e.g. http://127.0.0.1:7890 HTTPS_PROXY= # Custom base url for OpenAI API. default: https://api.openai.com diff --git a/src/components/Generator.tsx b/src/components/Generator.tsx index 2c1877b0..0c115ee0 100644 --- a/src/components/Generator.tsx +++ b/src/components/Generator.tsx @@ -1,138 +1,141 @@ -import { Index, Show, createEffect, createSignal, onCleanup, onMount } from 'solid-js' -import { useThrottleFn } from 'solidjs-use' -import { generateSignature } from '@/utils/auth' -import IconClear from './icons/Clear' -import MessageItem from './MessageItem' -import ErrorMessageItem from './ErrorMessageItem' -import type { ChatMessage, ErrorMessage } from '@/types' +import { Index, Show, createEffect, createSignal, onCleanup, onMount } from 'solid-js'; +import { useThrottleFn } from 'solidjs-use'; +import { generateSignature } from '@/utils/auth'; +import IconClear from './icons/Clear'; +import MessageItem from './MessageItem'; +import ErrorMessageItem from './ErrorMessageItem'; +import type { ChatMessage, ErrorMessage } from '@/types'; export default () => { - let inputRef: HTMLTextAreaElement - const [messageList, setMessageList] = createSignal([]) - const [currentError, setCurrentError] = createSignal() - const [currentAssistantMessage, setCurrentAssistantMessage] = createSignal('') - const [loading, setLoading] = createSignal(false) - const [controller, setController] = createSignal(null) - const [isStick, setStick] = createSignal(false) - const maxHistoryMessages = parseInt(import.meta.env.PUBLIC_MAX_HISTORY_MESSAGES || '9') + let inputRef: HTMLTextAreaElement; + const [messageList, setMessageList] = createSignal([]); + const [currentError, setCurrentError] = createSignal(); + const [currentAssistantMessage, setCurrentAssistantMessage] = createSignal(''); + const [loading, setLoading] = createSignal(false); + const [controller, setController] = createSignal(null); + const [isStick, setStick] = createSignal(false); + const maxHistoryMessages = parseInt(import.meta.env.PUBLIC_MAX_HISTORY_MESSAGES || '9'); - createEffect(() => (isStick() && smoothToBottom())) + createEffect(() => (isStick() && smoothToBottom())); onMount(() => { - let lastPostion = window.scrollY + let lastPostion = window.scrollY; window.addEventListener('scroll', () => { - const nowPostion = window.scrollY - nowPostion < lastPostion && setStick(false) - lastPostion = nowPostion - }) + const nowPostion = window.scrollY; + nowPostion < lastPostion && setStick(false); + lastPostion = nowPostion; + }); try { if (sessionStorage.getItem('messageList')) - setMessageList(JSON.parse(sessionStorage.getItem('messageList'))) + setMessageList(JSON.parse(sessionStorage.getItem('messageList'))); if (localStorage.getItem('stickToBottom') === 'stick') - setStick(true) + setStick(true); } catch (err) { - console.error(err) + console.error(err); } - window.addEventListener('beforeunload', handleBeforeUnload) + window.addEventListener('beforeunload', handleBeforeUnload); onCleanup(() => { - window.removeEventListener('beforeunload', handleBeforeUnload) - }) - }) + window.removeEventListener('beforeunload', handleBeforeUnload); + }); + }); const handleBeforeUnload = () => { - sessionStorage.setItem('messageList', JSON.stringify(messageList())) - isStick() ? localStorage.setItem('stickToBottom', 'stick') : localStorage.removeItem('stickToBottom') - } + sessionStorage.setItem('messageList', JSON.stringify(messageList())); + isStick() ? localStorage.setItem('stickToBottom', 'stick') : localStorage.removeItem('stickToBottom'); + }; - const handleButtonClick = async() => { - const inputValue = inputRef.value + const handleButtonClick = async () => { + const inputValue = inputRef.value; if (!inputValue) - return + return; - inputRef.value = '' + inputRef.value = ''; setMessageList([ ...messageList(), { role: 'user', content: inputValue, }, - ]) - requestWithLatestMessage() - instantToBottom() - } + ]); + requestWithLatestMessage(); + instantToBottom(); + }; const smoothToBottom = useThrottleFn(() => { - window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }) - }, 300, false, true) + window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }); + }, 300, false, true); const instantToBottom = () => { - window.scrollTo({ top: document.body.scrollHeight, behavior: 'instant' }) - } - - const requestWithLatestMessage = async() => { - setLoading(true) - setCurrentAssistantMessage('') - setCurrentError(null) - const storagePassword = localStorage.getItem('pass') + window.scrollTo({ top: document.body.scrollHeight, behavior: 'instant' }); + }; + + const requestWithLatestMessage = async () => { + setLoading(true); + setCurrentAssistantMessage(''); + setCurrentError(null); + const storagePassword = localStorage.getItem('pass'); try { - const controller = new AbortController() - setController(controller) - const requestMessageList = messageList().slice(-maxHistoryMessages) - const timestamp = Date.now() + const controller = new AbortController(); + setController(controller); + const requestMessageList = messageList().map(message => ({ + role: message.role === 'assistant' ? 'model' : 'user', + parts: [{ text: message.content }], + })).slice(-maxHistoryMessages); + const timestamp = Date.now(); const response = await fetch('/api/generate', { method: 'POST', body: JSON.stringify({ - messages: requestMessageList, + contents: requestMessageList, time: timestamp, pass: storagePassword, sign: await generateSignature({ t: timestamp, - m: requestMessageList?.[requestMessageList.length - 1]?.content || '', + m: requestMessageList?.[requestMessageList.length - 1]?.parts[0]?.text || '', }) }), signal: controller.signal, - }) + }); if (!response.ok) { - const error = await response.json() - console.error(error.error) - setCurrentError(error.error) - throw new Error('Request failed') + const error = await response.json(); + console.error(error.error); + setCurrentError(error.error); + throw new Error('Request failed'); } - const data = response.body + const data = response.body; if (!data) - throw new Error('No data') + throw new Error('No data'); - const reader = data.getReader() - const decoder = new TextDecoder('utf-8') - let done = false + const reader = data.getReader(); + const decoder = new TextDecoder('utf-8'); + let done = false; while (!done) { - const { value, done: readerDone } = await reader.read() + const { value, done: readerDone } = await reader.read(); if (value) { - const char = decoder.decode(value) + const char = decoder.decode(value); if (char === '\n' && currentAssistantMessage().endsWith('\n')) - continue + continue; if (char) - setCurrentAssistantMessage(currentAssistantMessage() + char) + setCurrentAssistantMessage(currentAssistantMessage() + char); - isStick() && instantToBottom() + isStick() && instantToBottom(); } - done = readerDone + done = readerDone; } } catch (e) { - console.error(e) - setLoading(false) - setController(null) - return + console.error(e); + setLoading(false); + setController(null); + return; } - archiveCurrentMessage() - isStick() && instantToBottom() - } + archiveCurrentMessage(); + isStick() && instantToBottom(); + }; const archiveCurrentMessage = () => { if (currentAssistantMessage()) { @@ -142,49 +145,49 @@ export default () => { role: 'assistant', content: currentAssistantMessage(), }, - ]) - setCurrentAssistantMessage('') - setLoading(false) - setController(null) + ]); + setCurrentAssistantMessage(''); + setLoading(false); + setController(null); // Disable auto-focus on touch devices if (!('ontouchstart' in document.documentElement || navigator.maxTouchPoints > 0)) - inputRef.focus() + inputRef.focus(); } - } + }; const clear = () => { - inputRef.value = '' - inputRef.style.height = 'auto' - setMessageList([]) - setCurrentAssistantMessage('') - setCurrentError(null) - } + inputRef.value = ''; + inputRef.style.height = 'auto'; + setMessageList([]); + setCurrentAssistantMessage(''); + setCurrentError(null); + }; const stopStreamFetch = () => { if (controller()) { - controller().abort() - archiveCurrentMessage() + controller().abort(); + archiveCurrentMessage(); } - } + }; const retryLastFetch = () => { if (messageList().length > 0) { - const lastMessage = messageList()[messageList().length - 1] + const lastMessage = messageList()[messageList().length - 1]; if (lastMessage.role === 'assistant') - setMessageList(messageList().slice(0, -1)) - requestWithLatestMessage() + setMessageList(messageList().slice(0, -1)); + requestWithLatestMessage(); } - } + }; const handleKeydown = (e: KeyboardEvent) => { if (e.isComposing || e.shiftKey) - return + return; if (e.key === 'Enter') { - e.preventDefault() - handleButtonClick() + e.preventDefault(); + handleButtonClick(); } - } + }; return (
@@ -204,7 +207,7 @@ export default () => { message={currentAssistantMessage} /> )} - { currentError() && } + {currentError() && } ( @@ -222,8 +225,8 @@ export default () => { autocomplete="off" autofocus onInput={() => { - inputRef.style.height = 'auto' - inputRef.style.height = `${inputRef.scrollHeight}px` + inputRef.style.height = 'auto'; + inputRef.style.height = `${inputRef.scrollHeight}px`; }} rows="1" class="gen-textarea" @@ -244,5 +247,5 @@ export default () => {
- ) -} + ); +}; diff --git a/src/pages/api/generate.ts b/src/pages/api/generate.ts index 1ccbf822..991bec05 100644 --- a/src/pages/api/generate.ts +++ b/src/pages/api/generate.ts @@ -1,19 +1,19 @@ // #vercel-disable-blocks -import { ProxyAgent, fetch } from 'undici' +import { fetch } from 'undici' // #vercel-end -import { generatePayload, parseOpenAIStream } from '@/utils/openAI' +import { generatePayload, parseStreamResponse } from '@/utils/openAI' import { verifySignature } from '@/utils/auth' import type { APIRoute } from 'astro' -const apiKey = import.meta.env.OPENAI_API_KEY -const httpsProxy = import.meta.env.HTTPS_PROXY -const baseUrl = ((import.meta.env.OPENAI_API_BASE_URL) || 'https://api.openai.com').trim().replace(/\/$/, '') +const apiKey = import.meta.env.GEMINI_API_KEY +const baseUrl = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent' + const sitePassword = import.meta.env.SITE_PASSWORD || '' const passList = sitePassword.split(',') || [] export const post: APIRoute = async(context) => { const body = await context.request.json() - const { sign, time, messages, pass, temperature } = body + const { sign, time, messages, pass } = body if (!messages) { return new Response(JSON.stringify({ error: { @@ -35,15 +35,9 @@ export const post: APIRoute = async(context) => { }, }), { status: 401 }) } - const initOptions = generatePayload(apiKey, messages, temperature) - // #vercel-disable-blocks - if (httpsProxy) - initOptions.dispatcher = new ProxyAgent(httpsProxy) - // #vercel-end + const initOptions = generatePayload(apiKey, messages) - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-expect-error - const response = await fetch(`${baseUrl}/v1/chat/completions`, initOptions).catch((err: Error) => { + const response = await fetch(`${baseUrl}?key=${apiKey}`, initOptions).catch((err: Error) => { console.error(err) return new Response(JSON.stringify({ error: { @@ -53,5 +47,5 @@ export const post: APIRoute = async(context) => { }), { status: 500 }) }) as Response - return parseOpenAIStream(response) as Response + return parseStreamResponse(response) as Response } diff --git a/src/utils/openAI.ts b/src/utils/openAI.ts index a51ae95b..cd51e194 100644 --- a/src/utils/openAI.ts +++ b/src/utils/openAI.ts @@ -1,28 +1,24 @@ +// openAI.ts import { createParser } from 'eventsource-parser' import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser' import type { ChatMessage } from '@/types' -export const model = import.meta.env.OPENAI_API_MODEL || 'gpt-3.5-turbo' - export const generatePayload = ( - apiKey: string, messages: ChatMessage[], - temperature: number, -): RequestInit & { dispatcher?: any } => ({ +): RequestInit => ({ headers: { 'Content-Type': 'application/json', - 'Authorization': `Bearer ${apiKey}`, }, method: 'POST', body: JSON.stringify({ - model, - messages, - temperature, - stream: true, + contents: messages.map(message => ({ + role: message.role, + parts: [{ text: message.content }] + })), }), }) -export const parseOpenAIStream = (rawResponse: Response) => { +export const parseStreamResponse = (rawResponse: Response) => { const encoder = new TextEncoder() const decoder = new TextDecoder() if (!rawResponse.ok) { @@ -36,25 +32,15 @@ export const parseOpenAIStream = (rawResponse: Response) => { async start(controller) { const streamParser = (event: ParsedEvent | ReconnectInterval) => { if (event.type === 'event') { - const data = event.data - if (data === '[DONE]') { - controller.close() - return - } try { - // response = { - // id: 'chatcmpl-6pULPSegWhFgi0XQ1DtgA3zTa1WR6', - // object: 'chat.completion.chunk', - // created: 1677729391, - // model: 'gpt-3.5-turbo-0301', - // choices: [ - // { delta: { content: '你' }, index: 0, finish_reason: null } - // ], - // } - const json = JSON.parse(data) - const text = json.choices[0].delta?.content || '' - const queue = encoder.encode(text) - controller.enqueue(queue) + const json = JSON.parse(event.data) + json.contents.forEach((content: { parts: { text: string }[] }) => { + content.parts.forEach(part => { + const text = part.text + const queue = encoder.encode(text) + controller.enqueue(queue) + }) + }) } catch (e) { controller.error(e) }