Gemini
Modify to support Gemini
babaohuang committed Dec 14, 2023
1 parent 20f0810 commit 4b6c989
Showing 4 changed files with 134 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .env.example
@@ -1,5 +1,5 @@
# Your API Key for OpenAI
OPENAI_API_KEY=
GEMINI_API_KEY=
# Provide proxy for OpenAI API. e.g. http://127.0.0.1:7890
HTTPS_PROXY=
# Custom base url for OpenAI API. default: https://api.openai.com
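As a usage note, a minimal local .env derived from the updated template would now set GEMINI_API_KEY rather than OPENAI_API_KEY. The value below is a placeholder, not a real key:

# Your API Key for Gemini
GEMINI_API_KEY=your-gemini-api-key-here

The comment line in the template itself still reads "# Your API Key for OpenAI", and HTTPS_PROXY remains in the template even though this commit removes the ProxyAgent handling in generate.ts that consumed it.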
215 changes: 109 additions & 106 deletions src/components/Generator.tsx
(GitHub Actions lint annotates the semicolon-terminated lines added in this file with "Extra semicolon" failures.)
@@ -1,138 +1,141 @@
import { Index, Show, createEffect, createSignal, onCleanup, onMount } from 'solid-js'
import { useThrottleFn } from 'solidjs-use'
import { generateSignature } from '@/utils/auth'
import IconClear from './icons/Clear'
import MessageItem from './MessageItem'
import ErrorMessageItem from './ErrorMessageItem'
import type { ChatMessage, ErrorMessage } from '@/types'
import { Index, Show, createEffect, createSignal, onCleanup, onMount } from 'solid-js';
import { useThrottleFn } from 'solidjs-use';
import { generateSignature } from '@/utils/auth';
import IconClear from './icons/Clear';
import MessageItem from './MessageItem';
import ErrorMessageItem from './ErrorMessageItem';
import type { ChatMessage, ErrorMessage } from '@/types';

export default () => {
let inputRef: HTMLTextAreaElement
const [messageList, setMessageList] = createSignal<ChatMessage[]>([])
const [currentError, setCurrentError] = createSignal<ErrorMessage>()
const [currentAssistantMessage, setCurrentAssistantMessage] = createSignal('')
const [loading, setLoading] = createSignal(false)
const [controller, setController] = createSignal<AbortController>(null)
const [isStick, setStick] = createSignal(false)
const maxHistoryMessages = parseInt(import.meta.env.PUBLIC_MAX_HISTORY_MESSAGES || '9')
let inputRef: HTMLTextAreaElement;
const [messageList, setMessageList] = createSignal<ChatMessage[]>([]);
const [currentError, setCurrentError] = createSignal<ErrorMessage>();
const [currentAssistantMessage, setCurrentAssistantMessage] = createSignal('');
const [loading, setLoading] = createSignal(false);
const [controller, setController] = createSignal<AbortController>(null);
const [isStick, setStick] = createSignal(false);
const maxHistoryMessages = parseInt(import.meta.env.PUBLIC_MAX_HISTORY_MESSAGES || '9');

createEffect(() => (isStick() && smoothToBottom()))
createEffect(() => (isStick() && smoothToBottom()));

onMount(() => {
let lastPostion = window.scrollY
let lastPostion = window.scrollY;

window.addEventListener('scroll', () => {
const nowPostion = window.scrollY
nowPostion < lastPostion && setStick(false)
lastPostion = nowPostion
})
const nowPostion = window.scrollY;
nowPostion < lastPostion && setStick(false);
lastPostion = nowPostion;
});

try {
if (sessionStorage.getItem('messageList'))
setMessageList(JSON.parse(sessionStorage.getItem('messageList')))
setMessageList(JSON.parse(sessionStorage.getItem('messageList')));

if (localStorage.getItem('stickToBottom') === 'stick')
setStick(true)
setStick(true);
} catch (err) {
console.error(err)
console.error(err);
}

window.addEventListener('beforeunload', handleBeforeUnload)
window.addEventListener('beforeunload', handleBeforeUnload);
onCleanup(() => {
window.removeEventListener('beforeunload', handleBeforeUnload)
})
})
window.removeEventListener('beforeunload', handleBeforeUnload);
});
});

const handleBeforeUnload = () => {
sessionStorage.setItem('messageList', JSON.stringify(messageList()))
isStick() ? localStorage.setItem('stickToBottom', 'stick') : localStorage.removeItem('stickToBottom')
}
sessionStorage.setItem('messageList', JSON.stringify(messageList()));
isStick() ? localStorage.setItem('stickToBottom', 'stick') : localStorage.removeItem('stickToBottom');
};

const handleButtonClick = async() => {
const inputValue = inputRef.value
const handleButtonClick = async () => {
const inputValue = inputRef.value;
if (!inputValue)
return
return;

inputRef.value = ''
inputRef.value = '';
setMessageList([
...messageList(),
{
role: 'user',
content: inputValue,
},
])
requestWithLatestMessage()
instantToBottom()
}
]);
requestWithLatestMessage();
instantToBottom();
};

const smoothToBottom = useThrottleFn(() => {
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' })
}, 300, false, true)
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
}, 300, false, true);

const instantToBottom = () => {
window.scrollTo({ top: document.body.scrollHeight, behavior: 'instant' })
}

const requestWithLatestMessage = async() => {
setLoading(true)
setCurrentAssistantMessage('')
setCurrentError(null)
const storagePassword = localStorage.getItem('pass')
window.scrollTo({ top: document.body.scrollHeight, behavior: 'instant' });
};

const requestWithLatestMessage = async () => {
setLoading(true);
setCurrentAssistantMessage('');
setCurrentError(null);
const storagePassword = localStorage.getItem('pass');
try {
const controller = new AbortController()
setController(controller)
const requestMessageList = messageList().slice(-maxHistoryMessages)
const timestamp = Date.now()
const controller = new AbortController();
setController(controller);
const requestMessageList = messageList().map(message => ({
role: message.role === 'assistant' ? 'model' : 'user',
parts: [{ text: message.content }],
})).slice(-maxHistoryMessages);
const timestamp = Date.now();
const response = await fetch('/api/generate', {
method: 'POST',
body: JSON.stringify({
messages: requestMessageList,
contents: requestMessageList,
time: timestamp,
pass: storagePassword,
sign: await generateSignature({
t: timestamp,
m: requestMessageList?.[requestMessageList.length - 1]?.content || '',
m: requestMessageList?.[requestMessageList.length - 1]?.parts[0]?.text || '',
})
}),
signal: controller.signal,
})
});
if (!response.ok) {
const error = await response.json()
console.error(error.error)
setCurrentError(error.error)
throw new Error('Request failed')
const error = await response.json();
console.error(error.error);
setCurrentError(error.error);
throw new Error('Request failed');
}
const data = response.body
const data = response.body;
if (!data)
throw new Error('No data')
throw new Error('No data');

const reader = data.getReader()
const decoder = new TextDecoder('utf-8')
let done = false
const reader = data.getReader();
const decoder = new TextDecoder('utf-8');
let done = false;

while (!done) {
const { value, done: readerDone } = await reader.read()
const { value, done: readerDone } = await reader.read();
if (value) {
const char = decoder.decode(value)
const char = decoder.decode(value);
if (char === '\n' && currentAssistantMessage().endsWith('\n'))
continue
continue;

if (char)
setCurrentAssistantMessage(currentAssistantMessage() + char)
setCurrentAssistantMessage(currentAssistantMessage() + char);

isStick() && instantToBottom()
isStick() && instantToBottom();
}
done = readerDone
done = readerDone;
}
} catch (e) {
console.error(e)
setLoading(false)
setController(null)
return
console.error(e);
setLoading(false);
setController(null);
return;
}
archiveCurrentMessage()
isStick() && instantToBottom()
}
archiveCurrentMessage();
isStick() && instantToBottom();
};

const archiveCurrentMessage = () => {
if (currentAssistantMessage()) {
@@ -142,49 +145,49 @@ export default () => {
role: 'assistant',
content: currentAssistantMessage(),
},
])
setCurrentAssistantMessage('')
setLoading(false)
setController(null)
]);
setCurrentAssistantMessage('');
setLoading(false);
setController(null);
// Disable auto-focus on touch devices
if (!('ontouchstart' in document.documentElement || navigator.maxTouchPoints > 0))
inputRef.focus()
inputRef.focus();
}
}
};

const clear = () => {
inputRef.value = ''
inputRef.style.height = 'auto'
setMessageList([])
setCurrentAssistantMessage('')
setCurrentError(null)
}
inputRef.value = '';
inputRef.style.height = 'auto';
setMessageList([]);
setCurrentAssistantMessage('');
setCurrentError(null);
};

const stopStreamFetch = () => {
if (controller()) {
controller().abort()
archiveCurrentMessage()
controller().abort();
archiveCurrentMessage();
}
}
};

const retryLastFetch = () => {
if (messageList().length > 0) {
const lastMessage = messageList()[messageList().length - 1]
const lastMessage = messageList()[messageList().length - 1];
if (lastMessage.role === 'assistant')
setMessageList(messageList().slice(0, -1))
requestWithLatestMessage()
setMessageList(messageList().slice(0, -1));
requestWithLatestMessage();
}
}
};

const handleKeydown = (e: KeyboardEvent) => {
if (e.isComposing || e.shiftKey)
return
return;

if (e.key === 'Enter') {
e.preventDefault()
handleButtonClick()
e.preventDefault();
handleButtonClick();
}
}
};

return (
<div my-6>
@@ -204,7 +207,7 @@ export default () => {
message={currentAssistantMessage}
/>
)}
{ currentError() && <ErrorMessageItem data={currentError()} onRetry={retryLastFetch} /> }
{currentError() && <ErrorMessageItem data={currentError()} onRetry={retryLastFetch} />}
<Show
when={!loading()}
fallback={() => (
@@ -222,8 +225,8 @@ export default () => {
autocomplete="off"
autofocus
onInput={() => {
inputRef.style.height = 'auto'
inputRef.style.height = `${inputRef.scrollHeight}px`
inputRef.style.height = 'auto';
inputRef.style.height = `${inputRef.scrollHeight}px`;
}}
rows="1"
class="gen-textarea"
@@ -244,5 +247,5 @@ export default () => {
</div>
</div>
</div>
)
}
);
};
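The main functional change in Generator.tsx is that requestWithLatestMessage now reshapes the chat history into Gemini's contents format before the request is sent: the 'assistant' role becomes 'model', and each message body is wrapped in a parts array. A small illustration of that mapping, with made-up sample messages:

import type { ChatMessage } from '@/types'

// Made-up history, purely to show the shape of the transformation
const history: ChatMessage[] = [
  { role: 'user', content: 'Hello' },
  { role: 'assistant', content: 'Hi! How can I help?' },
]

// Same mapping as the new code in requestWithLatestMessage
const contents = history.map(message => ({
  role: message.role === 'assistant' ? 'model' : 'user',
  parts: [{ text: message.content }],
}))

// contents:
// [
//   { role: 'user', parts: [{ text: 'Hello' }] },
//   { role: 'model', parts: [{ text: 'Hi! How can I help?' }] },
// ]

The request body field also changes from messages to contents, and the signature is now computed over the last entry's parts[0].text instead of content.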
24 changes: 9 additions & 15 deletions src/pages/api/generate.ts
@@ -1,19 +1,19 @@
// #vercel-disable-blocks
import { ProxyAgent, fetch } from 'undici'
import { fetch } from 'undici'
// #vercel-end
import { generatePayload, parseOpenAIStream } from '@/utils/openAI'
import { generatePayload, parseStreamResponse } from '@/utils/openAI'
import { verifySignature } from '@/utils/auth'
import type { APIRoute } from 'astro'

const apiKey = import.meta.env.OPENAI_API_KEY
const httpsProxy = import.meta.env.HTTPS_PROXY
const baseUrl = ((import.meta.env.OPENAI_API_BASE_URL) || 'https://api.openai.com').trim().replace(/\/$/, '')
const apiKey = import.meta.env.GEMINI_API_KEY
const baseUrl = 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent'

const sitePassword = import.meta.env.SITE_PASSWORD || ''
const passList = sitePassword.split(',') || []

export const post: APIRoute = async(context) => {
const body = await context.request.json()
const { sign, time, messages, pass, temperature } = body
const { sign, time, messages, pass } = body
if (!messages) {
return new Response(JSON.stringify({
error: {
@@ -35,15 +35,9 @@ export const post: APIRoute = async(context) => {
},
}), { status: 401 })
}
const initOptions = generatePayload(apiKey, messages, temperature)
// #vercel-disable-blocks
if (httpsProxy)
initOptions.dispatcher = new ProxyAgent(httpsProxy)
// #vercel-end
const initOptions = generatePayload(apiKey, messages)

// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error
const response = await fetch(`${baseUrl}/v1/chat/completions`, initOptions).catch((err: Error) => {
const response = await fetch(`${baseUrl}?key=${apiKey}`, initOptions).catch((err: Error) => {
console.error(err)
return new Response(JSON.stringify({
error: {
@@ -53,5 +47,5 @@ export const post: APIRoute = async(context) => {
}), { status: 500 })
}) as Response

return parseOpenAIStream(response) as Response
return parseStreamResponse(response) as Response
}
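generate.ts now imports generatePayload and parseStreamResponse from @/utils/openAI, but the diff for that fourth changed file did not load in this view. The sketch below is an assumption about what those helpers might look like for Gemini's streamGenerateContent endpoint, not the commit's actual code: generatePayload only has to build the JSON body, since the API key travels in the URL query string, and parseStreamResponse turns Gemini's streamed JSON into the plain-text stream the client appends character by character.

// Hypothetical sketch only; the real src/utils/openAI.ts in this commit may differ.
interface GeminiContent {
  role: string
  parts: { text: string }[]
}

export const generatePayload = (apiKey: string, contents: GeminiContent[]): RequestInit => ({
  // apiKey matches the call site in generate.ts, but the key is appended to the
  // request URL there (?key=...), so it is not used in the payload itself
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ contents }),
})

export const parseStreamResponse = (rawResponse: Response) => {
  if (!rawResponse.ok || !rawResponse.body)
    return rawResponse

  const encoder = new TextEncoder()
  const decoder = new TextDecoder('utf-8')
  const upstream = rawResponse.body.getReader()

  const stream = new ReadableStream({
    async start(controller) {
      // streamGenerateContent returns a streamed JSON array of responses; as a
      // simplification, pull every "text" field out of each decoded chunk.
      // (A text value split across chunk boundaries would be missed.)
      while (true) {
        const { value, done } = await upstream.read()
        if (done)
          break
        const chunk = decoder.decode(value)
        for (const match of chunk.matchAll(/"text":\s*("(?:[^"\\]|\\.)*")/g))
          controller.enqueue(encoder.encode(JSON.parse(match[1])))
      }
      controller.close()
    },
  })

  return new Response(stream)
}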