diff --git a/app/client/api.ts b/app/client/api.ts index 88c7ccff..2e945eea 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -1,10 +1,10 @@ import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm"; -import { CacheType, ModelType } from "../store"; +import { CacheType, Model } from "../store"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; export const Models = ["gpt-3.5-turbo", "gpt-4"] as const; -export type ChatModel = ModelType; +export type ChatModel = Model; export interface MultimodalContent { type: "text" | "image_url"; @@ -40,7 +40,6 @@ export interface ChatOptions { usage?: CompletionUsage, ) => void; onError?: (err: Error) => void; - onController?: (controller: AbortController) => void; } export interface LLMUsage { @@ -60,7 +59,7 @@ export interface ModelRecord { buffer_size_required_bytes?: number; low_resource_required?: boolean; required_features?: string[]; - recommended_config: { + recommended_config?: { temperature?: number; top_p?: number; presence_penalty?: number; @@ -71,4 +70,5 @@ export interface ModelRecord { export abstract class LLMApi { abstract chat(options: ChatOptions): Promise; abstract abort(): Promise; + abstract models(): Promise; } diff --git a/app/client/mlcllm.ts b/app/client/mlcllm.ts new file mode 100644 index 00000000..af4bb027 --- /dev/null +++ b/app/client/mlcllm.ts @@ -0,0 +1,100 @@ +import log from "loglevel"; +import { ChatOptions, LLMApi } from "./api"; +import { ChatCompletionFinishReason } from "@mlc-ai/web-llm"; + +export class MlcLLMApi implements LLMApi { + private endpoint: string; + private abortController: AbortController | null = null; + + constructor(endpoint: string) { + this.endpoint = endpoint; + } + + async chat(options: ChatOptions) { + const { messages, config } = options; + + const payload = { + messages: messages, + ...config, + }; + + // Instantiate a new AbortController for this request + this.abortController = new AbortController(); + const { signal } = this.abortController; + + let reply: string = ""; + let stopReason: ChatCompletionFinishReason | undefined; + + try { + const response = await fetch(`${this.endpoint}/v1/chat/completions`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + signal, + }); + + if (config.stream) { + const reader = response.body!.getReader(); + while (true) { + const { value, done } = await reader.read(); + if (done) break; + // Extracting the data part from the server response + const chunk = new TextDecoder("utf-8").decode(value); + const result = chunk.match(/data: (.+)/); + if (result) { + const data = JSON.parse(result[1]); + if (data.choices && data.choices.length > 0) { + reply += data.choices[0].delta.content; // Append the content + options.onUpdate?.(reply, chunk); // Handle the chunk update + + if (data.choices[0].finish_reason) { + stopReason = data.choices[0].finish_reason; + } + } + } + + if (chunk === "[DONE]") { + // Ending the stream when "[DONE]" is found + break; + } + } + options.onFinish(reply, stopReason); + } else { + const data = await response.json(); + options.onFinish( + data.choices[0].message.content, + data.choices[0].finish_reason, + ); + } + } catch (error: any) { + if (error.name === "AbortError") { + log.info("MLC_LLM: chat aborted"); + } else { + log.error("MLC_LLM: Fetch error:", error); + options.onError?.(error); + } + } + } + + // Implements the abort method to cancel the request + async abort() { + 
this.abortController?.abort(); + } + + async models() { + try { + const response = await fetch(`${this.endpoint}/v1/models`, { + method: "GET", + }); + const jsonRes = await response.json(); + return jsonRes.data.map((model: { id: string }) => ({ + name: model.id, + display_name: model.id.split("/")[model.id.split("/").length - 1], + })); + } catch (error: any) { + log.error("MLC_LLM: Fetch error:", error); + } + } +} diff --git a/app/client/webllm.ts b/app/client/webllm.ts index e105bbdd..8473196e 100644 --- a/app/client/webllm.ts +++ b/app/client/webllm.ts @@ -17,6 +17,7 @@ import { import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api"; import { LogLevel } from "@mlc-ai/web-llm"; import { fixMessage } from "../utils"; +import { DEFAULT_MODELS } from "../constant"; const KEEP_ALIVE_INTERVAL = 5_000; @@ -215,6 +216,8 @@ export class WebLLMApi implements LLMApi { usage: chatCompletion.usage, }; } -} -export const WebLLMContext = createContext(undefined); + async models() { + return DEFAULT_MODELS; + } +} diff --git a/app/components/chat.tsx b/app/components/chat.tsx index e6ff6c33..c59491ae 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -45,7 +45,8 @@ import { createMessage, useAppConfig, DEFAULT_TOPIC, - ModelType, + Model, + ModelClient, } from "../store"; import { @@ -92,9 +93,9 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command"; import { prettyObject } from "../utils/format"; import { ExportMessageModal } from "./exporter"; import { MultimodalContent } from "../client/api"; -import { WebLLMContext } from "../client/webllm"; import { Template, useTemplateStore } from "../store/template"; import Image from "next/image"; +import { MLCLLMContext, WebLLMContext } from "../context"; export function ScrollDownToast(prop: { show: boolean; onclick: () => void }) { return ( @@ -125,7 +126,7 @@ export function SessionConfigModel(props: { onClose: () => void }) { }; return ( -
+    <div className="screen-model-container">
props.onClose()} @@ -556,7 +557,7 @@ export function ChatActions(props: { onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; - config.selectModel(s[0] as ModelType); + config.selectModel(s[0] as Model); showToast(s[0]); }} /> @@ -606,6 +607,12 @@ function _Chat() { const [uploading, setUploading] = useState(false); const [showEditPromptModal, setShowEditPromptModal] = useState(false); const webllm = useContext(WebLLMContext)!; + const mlcllm = useContext(MLCLLMContext)!; + + const llm = + config.modelClientType === ModelClient.MLCLLM_API ? mlcllm : webllm; + + const models = config.models; // prompt hints const promptStore = usePromptStore(); @@ -685,7 +692,7 @@ function _Chat() { if (isStreaming) return; - chatStore.onUserInput(userInput, webllm, attachImages); + chatStore.onUserInput(userInput, llm, attachImages); setAttachImages([]); localStorage.setItem(LAST_INPUT_KEY, userInput); setUserInput(""); @@ -713,7 +720,7 @@ function _Chat() { // stop response const onUserStop = () => { - webllm.abort(); + llm.abort(); chatStore.stopStreaming(); }; @@ -836,7 +843,7 @@ function _Chat() { // resend the message const textContent = getMessageTextContent(userMessage); const images = getMessageImages(userMessage); - chatStore.onUserInput(textContent, webllm, images); + chatStore.onUserInput(textContent, llm, images); inputRef.current?.focus(); }; @@ -867,7 +874,13 @@ function _Chat() { ] : [], ); - }, [config.sendPreviewBubble, context, session.messages, userInput]); + }, [ + config.sendPreviewBubble, + context, + session.messages, + session.messages.length, + userInput, + ]); const [msgRenderIndex, _setMsgRenderIndex] = useState( Math.max(0, renderMessages.length - CHAT_PAGE_SIZE), @@ -1183,10 +1196,9 @@ function _Chat() { )} {message.role === "assistant" && (
-                  {config.models.find((m) => m.name === message.model)
-                    ? config.models.find(
-                        (m) => m.name === message.model,
-                      )!.display_name
+                  {models.find((m) => m.name === message.model)
+                    ? models.find((m) => m.name === message.model)!
+                        .display_name
                     : message.model}
)} diff --git a/app/components/emoji.tsx b/app/components/emoji.tsx index 0209acf2..a7d64f11 100644 --- a/app/components/emoji.tsx +++ b/app/components/emoji.tsx @@ -4,7 +4,7 @@ import EmojiPicker, { Theme as EmojiTheme, } from "emoji-picker-react"; -import { ModelType } from "../store"; +import { Model } from "../store"; import MlcIcon from "../icons/mlc.svg"; @@ -31,7 +31,7 @@ export function AvatarPicker(props: { ); } -export function Avatar(props: { model?: ModelType; avatar?: string }) { +export function Avatar(props: { model?: Model; avatar?: string }) { if (props.model) { return (
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx
index 18ba266b..9e9a456c 100644
--- a/app/components/exporter.tsx
+++ b/app/components/exporter.tsx
@@ -1,5 +1,5 @@
 /* eslint-disable @next/next/no-img-element */
-import { ChatMessage, ModelType, useAppConfig, useChatStore } from "../store";
+import { ChatMessage, Model, useAppConfig, useChatStore } from "../store";
 import Locale from "../locales";
 import styles from "./exporter.module.scss";
 import {
@@ -46,7 +46,7 @@ const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
 export function ExportMessageModal(props: { onClose: () => void }) {
   return (
-    <div className="modal-template">
+    <div className="screen-model-container">
{ return { webllm, isWebllmActive }; }; +const useMlcLLM = () => { + const config = useAppConfig(); + const [mlcllm, setMlcLlm] = useState(undefined); + + useEffect(() => { + setMlcLlm(new MlcLLMApi(config.modelConfig.mlc_endpoint)); + }, [config.modelConfig.mlc_endpoint, setMlcLlm]); + + return mlcllm; +}; + const useLoadUrlParam = () => { const config = useAppConfig(); @@ -306,14 +319,32 @@ const useLogLevel = (webllm?: WebLLMApi) => { }, [config.logLevel, webllm?.webllm?.engine]); }; +const useModels = (mlcllm: MlcLLMApi | undefined) => { + const config = useAppConfig(); + + useEffect(() => { + if (config.modelClientType == ModelClient.WEBLLM) { + config.setModels(DEFAULT_MODELS); + } else if (config.modelClientType == ModelClient.MLCLLM_API) { + if (mlcllm) { + mlcllm.models().then((models) => { + config.setModels(models); + }); + } + } + }, [config.modelClientType, mlcllm]); +}; + export function Home() { const hasHydrated = useHasHydrated(); const { webllm, isWebllmActive } = useWebLLM(); + const mlcllm = useMlcLLM(); useSwitchTheme(); useHtmlLang(); useLoadUrlParam(); useStopStreamingMessages(); + useModels(mlcllm); useLogLevel(webllm); if (!hasHydrated || !webllm || !isWebllmActive) { @@ -328,7 +359,9 @@ export function Home() { - + + + diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 58137971..665e15cf 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -1,222 +1,275 @@ import { ModalConfigValidator, ModelConfig, - ModelType, useAppConfig, + ModelClient, } from "../store"; +import CancelIcon from "../icons/cancel.svg"; +import ConfirmIcon from "../icons/confirm.svg"; +import ConnectIcon from "../icons/connection.svg"; + import Locale from "../locales"; import { InputRange } from "./input-range"; -import { ListItem, Select } from "./ui-lib"; -import React from "react"; +import { List, ListItem, Modal, Select } from "./ui-lib"; +import React, { useState } from "react"; +import { IconButton } from "./button"; -export function ModelConfigList(props: { - modelConfig: ModelConfig; - selectModel: (model: ModelType) => void; - updateConfig: (updater: (config: ModelConfig) => void) => void; -}) { +export function ModelConfigList() { const config = useAppConfig(); const models = config.models; + const [showApiConnectModel, setShowApiConnectModel] = useState(false); + + const [endpointInput, setEndpointInput] = useState( + config.modelConfig.mlc_endpoint, + ); + + const updateModelConfig = (updater: (config: ModelConfig) => void) => { + const modelConfig = { ...config.modelConfig } as ModelConfig; + updater(modelConfig); + config.update((config) => (config.modelConfig = modelConfig)); + }; return ( <> - + - - { - props.updateConfig( - (config) => - (config.temperature = ModalConfigValidator.temperature( - e.currentTarget.valueAsNumber, - )), - ); - }} - > - - - { - props.updateConfig( - (config) => - (config.top_p = ModalConfigValidator.top_p( - e.currentTarget.valueAsNumber, - )), - ); - }} - > - - - - props.updateConfig( - (config) => - (config.max_tokens = ModalConfigValidator.max_tokens( - e.currentTarget.valueAsNumber, - )), - ) - } - > - - <> - - { - props.updateConfig( - (config) => - (config.presence_penalty = - ModalConfigValidator.presence_penalty( + {config.modelClientType === ModelClient.WEBLLM && ( + <> + + + + + { + updateModelConfig( + (config) => + (config.temperature = ModalConfigValidator.temperature( e.currentTarget.valueAsNumber, )), - ); - }} - > - - - - { - props.updateConfig( - (config) => - 
(config.frequency_penalty = - ModalConfigValidator.frequency_penalty( + ); + }} + > + + + { + updateModelConfig( + (config) => + (config.top_p = ModalConfigValidator.top_p( e.currentTarget.valueAsNumber, )), - ); - }} - > - + ); + }} + > + + + + updateModelConfig( + (config) => + (config.max_tokens = ModalConfigValidator.max_tokens( + e.currentTarget.valueAsNumber, + )), + ) + } + > + + + { + updateModelConfig( + (config) => + (config.presence_penalty = + ModalConfigValidator.presence_penalty( + e.currentTarget.valueAsNumber, + )), + ); + }} + > + - - - props.updateConfig( - (config) => - (config.enableInjectSystemPrompts = e.currentTarget.checked), - ) - } - > - + + { + updateModelConfig( + (config) => + (config.frequency_penalty = + ModalConfigValidator.frequency_penalty( + e.currentTarget.valueAsNumber, + )), + ); + }} + > + + + )} - - - props.updateConfig( - (config) => (config.template = e.currentTarget.value), - ) - } - > - - - - - props.updateConfig( - (config) => (config.historyMessageCount = e.target.valueAsNumber), - ) - } - > - + {config.modelClientType === ModelClient.MLCLLM_API && ( + <> + + } + text={Locale.Settings.MlcLlmApi.Connect.Title} + onClick={() => setShowApiConnectModel(true)} + type="primary" + /> + {" "} + + + + + )} - - - props.updateConfig( - (config) => - (config.compressMessageLengthThreshold = - e.currentTarget.valueAsNumber), - ) - } - > - - - - props.updateConfig( - (config) => (config.sendMemory = e.currentTarget.checked), - ) - } - > - + {showApiConnectModel && ( +
+ setShowApiConnectModel(false)} + actions={[ + { + setShowApiConnectModel(false); + }} + icon={} + bordered + shadow + tabIndex={0} + >, + { + if (!/^(http:\/\/|https:\/\/).*/i.test(endpointInput)) { + config.update( + (config) => + (config.modelConfig.mlc_endpoint = + "http://" + endpointInput), + ); + } else { + config.update( + (config) => + (config.modelConfig.mlc_endpoint = endpointInput), + ); + } + setShowApiConnectModel(false); + config.update((config) => { + config.modelClientType = ModelClient.MLCLLM_API; + }); + }} + icon={} + bordered + shadow + tabIndex={0} + >, + ]} + > + + + setEndpointInput(e.currentTarget.value)} + > + + + +
+      )}
     </>
   );
 }
diff --git a/app/components/settings.tsx b/app/components/settings.tsx
index a3556bf2..db642511 100644
--- a/app/components/settings.tsx
+++ b/app/components/settings.tsx
@@ -35,14 +35,14 @@ import { InputRange } from "./input-range";
 import { useNavigate } from "react-router-dom";
 import { nanoid } from "nanoid";
 import { LogLevel } from "@mlc-ai/web-llm";
-import { WebLLMContext } from "../client/webllm";
+import { WebLLMContext } from "../context";
 function EditPromptModal(props: { id: string; onClose: () => void }) {
   const promptStore = usePromptStore();
   const prompt = promptStore.get(props.id);
   return prompt ? (
-    <div className="modal-template">
+    <div className="screen-model-container">
void }) {
   }, [searchInput]);
   return (
-    <div className="modal-template">
+    <div className="screen-model-container">
props.onClose?.()} @@ -278,15 +278,87 @@ export function Settings() {
- { - const modelConfig = { ...config.modelConfig }; - updater(modelConfig); - config.update((config) => (config.modelConfig = modelConfig)); - }} - /> + + + + + + + config.update( + (config) => + (config.enableInjectSystemPrompts = + e.currentTarget.checked), + ) + } + > + + + + config.update( + (config) => (config.template = e.currentTarget.value), + ) + } + > + + + + config.update( + (config) => + (config.historyMessageCount = e.target.valueAsNumber), + ) + } + > + + + + config.update( + (config) => + (config.compressMessageLengthThreshold = + e.currentTarget.valueAsNumber), + ) + } + > + + + + config.update( + (config) => (config.sendMemory = e.currentTarget.checked), + ) + } + > + diff --git a/app/components/template.tsx b/app/components/template.tsx index bb4073fb..f9652ec7 100644 --- a/app/components/template.tsx +++ b/app/components/template.tsx @@ -22,7 +22,7 @@ import { import { ChatMessage, createMessage, - ModelType, + Model, useAppConfig, useChatStore, } from "../store"; @@ -67,7 +67,7 @@ function reorder(list: T[], startIndex: number, endIndex: number): T[] { return result; } -export function TemplateAvatar(props: { avatar: string; model?: ModelType }) { +export function TemplateAvatar(props: { avatar: string; model?: Model }) { return props.avatar !== DEFAULT_TEMPLATE_AVATAR ? ( ) : ( @@ -523,7 +523,7 @@ export function TemplatePage() {
      {editingTemplate && (
-        <div className="modal-template">
+        <div className="screen-model-container">
(undefined); +export const MLCLLMContext = createContext(undefined); diff --git a/app/layout.tsx b/app/layout.tsx index 8bb45598..bacfc40a 100644 --- a/app/layout.tsx +++ b/app/layout.tsx @@ -28,7 +28,7 @@ const cspHeader = ` default-src 'self'; script-src 'self' 'unsafe-eval' 'unsafe-inline'; worker-src 'self'; - connect-src 'self' blob: data: https:; + connect-src 'self' blob: data: https: http:; style-src 'self' 'unsafe-inline'; img-src 'self' blob: data: https:; font-src 'self'; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 5a447008..05a220ca 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -207,94 +207,6 @@ const cn = { Check: "重新检查", NoAccess: "输入 API Key 或访问密码查看余额", }, - - Access: { - AccessCode: { - Title: "访问密码", - SubTitle: "管理员已开启加密访问", - Placeholder: "请输入访问密码", - }, - CustomEndpoint: { - Title: "自定义接口", - SubTitle: "是否使用自定义 Azure 或 OpenAI 服务", - }, - Provider: { - Title: "模型服务商", - SubTitle: "切换不同的服务商", - }, - OpenAI: { - ApiKey: { - Title: "API Key", - SubTitle: "使用自定义 OpenAI Key 绕过密码访问限制", - Placeholder: "OpenAI API Key", - }, - - Endpoint: { - Title: "接口地址", - SubTitle: "除默认地址外,必须包含 http(s)://", - }, - }, - Azure: { - ApiKey: { - Title: "接口密钥", - SubTitle: "使用自定义 Azure Key 绕过密码访问限制", - Placeholder: "Azure API Key", - }, - - Endpoint: { - Title: "接口地址", - SubTitle: "样例:", - }, - - ApiVerion: { - Title: "接口版本 (azure api version)", - SubTitle: "选择指定的部分版本", - }, - }, - Anthropic: { - ApiKey: { - Title: "接口密钥", - SubTitle: "使用自定义 Anthropic Key 绕过密码访问限制", - Placeholder: "Anthropic API Key", - }, - - Endpoint: { - Title: "接口地址", - SubTitle: "样例:", - }, - - ApiVerion: { - Title: "接口版本 (claude api version)", - SubTitle: "选择一个特定的 API 版本输入", - }, - }, - Google: { - ApiKey: { - Title: "API 密钥", - SubTitle: "从 Google AI 获取您的 API 密钥", - Placeholder: "输入您的 Google AI Studio API 密钥", - }, - - Endpoint: { - Title: "终端地址", - SubTitle: "示例:", - }, - - ApiVersion: { - Title: "API 版本(仅适用于 gemini-pro)", - SubTitle: "选择一个特定的 API 版本", - }, - }, - CacheType: { - Title: "缓存类型", - SubTitle: "使用IndexDB或Cache API作为模型缓存", - }, - CustomModel: { - Title: "自定义模型名", - SubTitle: "增加自定义模型可选项,使用英文逗号隔开", - }, - }, - Model: "模型 (model)", Temperature: { Title: "随机性 (temperature)", diff --git a/app/locales/en.ts b/app/locales/en.ts index a72d55a7..fb55b51e 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -215,44 +215,22 @@ const en = { Check: "Check", NoAccess: "Enter API Key to check balance", }, - Access: { - AccessCode: { - Title: "Access Code", - SubTitle: "Access control Enabled", - Placeholder: "Enter Code", - }, - CustomEndpoint: { - Title: "Custom Endpoint", - SubTitle: "Use custom Azure or OpenAI service", - }, - Provider: { - Title: "Model Provider", - SubTitle: "Select Azure or OpenAI", - }, - CustomModel: { - Title: "Custom Models", - SubTitle: "Custom model options, seperated by comma", - }, - Google: { - ApiKey: { - Title: "API Key", - SubTitle: "Obtain your API Key from Google AI", - Placeholder: "Enter your Google AI Studio API Key", - }, - - Endpoint: { - Title: "Endpoint Address", - SubTitle: "Example:", - }, + Model: "Model", + ModelClientType: { + Title: "Model Type", + WebLlm: "WebLLM Models", + MlcLlm: "MLC-LLM REST API Endpoint (Advanced)", + }, - ApiVersion: { - Title: "API Version (specific to gemini-pro)", - SubTitle: "Select a specific API version", - }, + MlcLlmApi: { + Title: "API Endpoint", + SubTitle: "MLC-LLM serve API endpoint", + Connect: { + Title: "Connect", + SubTitle: "Connect to the API endpoint", }, }, - Model: "Model", Temperature: { Title: 
"Temperature", SubTitle: "A larger value makes the more random output", @@ -397,6 +375,9 @@ const en = { Error: "The WebLLM worker has lost connection. Please close all tabs of WebLLM Chat and try opening WebLLM Chat again.", }, + MlcLLMConnect: { + Title: "Connect to MLC-LLM API Endpoint", + }, }; export default en; diff --git a/app/locales/pt.ts b/app/locales/pt.ts index 7c662100..96bc8608 100644 --- a/app/locales/pt.ts +++ b/app/locales/pt.ts @@ -203,26 +203,6 @@ const pt: PartialLocaleType = { Check: "Verificar", NoAccess: "Insira a Chave API para verificar o saldo", }, - Access: { - AccessCode: { - Title: "Código de Acesso", - SubTitle: "Controle de Acesso Habilitado", - Placeholder: "Insira o Código", - }, - CustomEndpoint: { - Title: "Endpoint Personalizado", - SubTitle: "Use serviço personalizado Azure ou OpenAI", - }, - Provider: { - Title: "Provedor do Modelo", - SubTitle: "Selecione Azure ou OpenAI", - }, - CustomModel: { - Title: "Modelos Personalizados", - SubTitle: "Opções de modelo personalizado, separados por vírgula", - }, - }, - Model: "Modelo", Temperature: { Title: "Temperatura", diff --git a/app/locales/sk.ts b/app/locales/sk.ts index cc1b72e4..637c4f78 100644 --- a/app/locales/sk.ts +++ b/app/locales/sk.ts @@ -205,43 +205,6 @@ const sk: PartialLocaleType = { Check: "Skontrolovať", NoAccess: "Zadajte API kľúč na skontrolovanie zostatku", }, - Access: { - AccessCode: { - Title: "Prístupový kód", - SubTitle: "Povolený prístupový kód", - Placeholder: "Zadajte kód", - }, - CustomEndpoint: { - Title: "Vlastný koncový bod", - SubTitle: "Použiť vlastnú službu Azure alebo OpenAI", - }, - Provider: { - Title: "Poskytovateľ modelu", - SubTitle: "Vyberte Azure alebo OpenAI", - }, - CustomModel: { - Title: "Vlastné modely", - SubTitle: "Možnosti vlastného modelu, oddelené čiarkou", - }, - Google: { - ApiKey: { - Title: "API kľúč", - SubTitle: - "Obísť obmedzenia prístupu heslom pomocou vlastného API kľúča Google AI Studio", - Placeholder: "API kľúč Google AI Studio", - }, - - Endpoint: { - Title: "Adresa koncového bodu", - SubTitle: "Príklad:", - }, - - ApiVersion: { - Title: "Verzia API (gemini-pro verzia API)", - SubTitle: "Vyberte špecifickú verziu časti", - }, - }, - }, Model: "Model", Temperature: { diff --git a/app/locales/tw.ts b/app/locales/tw.ts index bdde2f36..8f25d3b2 100644 --- a/app/locales/tw.ts +++ b/app/locales/tw.ts @@ -251,89 +251,6 @@ const tw = { NoAccess: "輸入 API Key 檢視餘額", }, - Access: { - AccessCode: { - Title: "存取密碼", - SubTitle: "管理員已開啟加密存取", - Placeholder: "請輸入存取密碼", - }, - CustomEndpoint: { - Title: "自定義介面 (Endpoint)", - SubTitle: "是否使用自定義 Azure 或 OpenAI 服務", - }, - Provider: { - Title: "模型服務商", - SubTitle: "切換不同的服務商", - }, - OpenAI: { - ApiKey: { - Title: "API Key", - SubTitle: "使用自定義 OpenAI Key 繞過密碼存取限制", - Placeholder: "OpenAI API Key", - }, - - Endpoint: { - Title: "介面(Endpoint) 地址", - SubTitle: "除預設地址外,必須包含 http(s)://", - }, - }, - Azure: { - ApiKey: { - Title: "介面金鑰", - SubTitle: "使用自定義 Azure Key 繞過密碼存取限制", - Placeholder: "Azure API Key", - }, - - Endpoint: { - Title: "介面(Endpoint) 地址", - SubTitle: "樣例:", - }, - - ApiVerion: { - Title: "介面版本 (azure api version)", - SubTitle: "選擇指定的部分版本", - }, - }, - Anthropic: { - ApiKey: { - Title: "API 金鑰", - SubTitle: "從 Anthropic AI 取得您的 API 金鑰", - Placeholder: "Anthropic API Key", - }, - - Endpoint: { - Title: "終端地址", - SubTitle: "範例:", - }, - - ApiVerion: { - Title: "API 版本 (claude api version)", - SubTitle: "選擇一個特定的 API 版本輸入", - }, - }, - Google: { - ApiKey: { - Title: "API 金鑰", - SubTitle: "從 Google AI 取得您的 
API 金鑰", - Placeholder: "輸入您的 Google AI Studio API 金鑰", - }, - - Endpoint: { - Title: "終端地址", - SubTitle: "範例:", - }, - - ApiVersion: { - Title: "API 版本(僅適用於 gemini-pro)", - SubTitle: "選擇一個特定的 API 版本", - }, - }, - CustomModel: { - Title: "自定義模型名", - SubTitle: "增加自定義模型可選項,使用英文逗號隔開", - }, - }, - Model: "模型 (model)", Temperature: { Title: "隨機性 (temperature)", diff --git a/app/store/chat.ts b/app/store/chat.ts index 06c7f7b7..47bac112 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -3,7 +3,7 @@ import { trimTopic, getMessageTextContent } from "../utils"; import log from "loglevel"; import Locale, { getLang } from "../locales"; import { showToast } from "../components/ui-lib"; -import { ModelConfig, ModelType, useAppConfig } from "./config"; +import { ModelConfig, Model, useAppConfig, ConfigType } from "./config"; import { createEmptyTemplate, Template } from "./template"; import { DEFAULT_INPUT_TEMPLATE, @@ -11,7 +11,7 @@ import { DEFAULT_SYSTEM_TEMPLATE, StoreKey, } from "../constant"; -import { RequestMessage, MultimodalContent } from "../client/api"; +import { RequestMessage, MultimodalContent, LLMApi } from "../client/api"; import { estimateTokenLength } from "../utils/token"; import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; @@ -24,7 +24,7 @@ export type ChatMessage = RequestMessage & { isError?: boolean; id: string; stopReason?: ChatCompletionFinishReason; - model?: ModelType; + model?: Model; usage?: CompletionUsage; }; @@ -92,13 +92,15 @@ function countMessages(msgs: ChatMessage[]) { ); } -function fillTemplateWith(input: string, modelConfig: ModelConfig) { +function fillTemplateWith(input: string, modelConfig: ConfigType) { // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model - const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model); + const modelInfo = DEFAULT_MODELS.find( + (m) => m.name === modelConfig.modelConfig.model, + ); const vars = { provider: modelInfo?.provider || "unknown", - model: modelConfig.model, + model: modelConfig.modelConfig.model, time: new Date().toString(), lang: getLang(), input: input, @@ -269,19 +271,19 @@ export const useChatStore = createPersistStore( })); }, - onNewMessage(message: ChatMessage, webllm: WebLLMApi) { + onNewMessage(message: ChatMessage, llm: LLMApi) { get().updateCurrentSession((session) => { session.messages = session.messages.concat(); session.lastUpdate = Date.now(); }); get().updateStat(message); - get().summarizeSession(webllm); + get().summarizeSession(llm); }, - onUserInput(content: string, webllm: WebLLMApi, attachImages?: string[]) { + onUserInput(content: string, llm: LLMApi, attachImages?: string[]) { const modelConfig = useAppConfig.getState().modelConfig; - const userContent = fillTemplateWith(content, modelConfig); + const userContent = fillTemplateWith(content, useAppConfig.getState()); log.debug("[User Input] after template: ", userContent); let mContent: string | MultimodalContent[] = userContent; @@ -335,7 +337,7 @@ export const useChatStore = createPersistStore( }); // make request - webllm.chat({ + llm.chat({ messages: sendMessages, config: { ...modelConfig, @@ -357,7 +359,7 @@ export const useChatStore = createPersistStore( botMessage.stopReason = stopReason; if (message) { botMessage.content = message; - get().onNewMessage(botMessage, webllm); + get().onNewMessage(botMessage, llm); } get().updateCurrentSession((session) => { session.isGenerating = false; @@ -396,7 +398,8 @@ export const useChatStore = createPersistStore( 
getMessagesWithMemory() { const session = get().currentSession(); - const modelConfig = useAppConfig.getState().modelConfig; + const config = useAppConfig.getState(); + const modelConfig = config.modelConfig; const clearContextIndex = session.clearContextIndex ?? 0; const messages = session.messages.slice(); const totalMessageCount = session.messages.length; @@ -405,7 +408,7 @@ export const useChatStore = createPersistStore( const contextPrompts = session.template.context.slice(); // system prompts, to get close to OpenAI Web ChatGPT - const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts; + const shouldInjectSystemPrompts = config.enableInjectSystemPrompts; var systemPrompts: ChatMessage[] = []; systemPrompts = shouldInjectSystemPrompts @@ -413,7 +416,7 @@ export const useChatStore = createPersistStore( createMessage({ role: "system", content: fillTemplateWith("", { - ...modelConfig, + ...config, template: DEFAULT_SYSTEM_TEMPLATE, }), }), @@ -428,7 +431,7 @@ export const useChatStore = createPersistStore( // long term memory const shouldSendLongTermMemory = - modelConfig.sendMemory && + config.sendMemory && session.memoryPrompt && session.memoryPrompt.length > 0 && session.lastSummarizeIndex > clearContextIndex; @@ -440,7 +443,7 @@ export const useChatStore = createPersistStore( // short term memory const shortTermMemoryStartIndex = Math.max( 0, - totalMessageCount - modelConfig.historyMessageCount, + totalMessageCount - config.historyMessageCount, ); // lets concat send messages, including 4 parts: @@ -498,7 +501,7 @@ export const useChatStore = createPersistStore( }); }, - summarizeSession(webllm: WebLLMApi) { + summarizeSession(llm: LLMApi) { const config = useAppConfig.getState(); const session = get().currentSession(); const modelConfig = useAppConfig.getState().modelConfig; @@ -519,7 +522,7 @@ export const useChatStore = createPersistStore( content: Locale.Store.Prompt.Topic, }), ); - webllm.chat({ + llm.chat({ messages: topicMessages, config: { model: modelConfig.model, @@ -548,7 +551,7 @@ export const useChatStore = createPersistStore( if (historyMsgLength > modelConfig?.max_tokens ?? 
4000) { const n = toBeSummarizedMsgs.length; toBeSummarizedMsgs = toBeSummarizedMsgs.slice( - Math.max(0, n - modelConfig.historyMessageCount), + Math.max(0, n - config.historyMessageCount), ); } @@ -561,12 +564,12 @@ export const useChatStore = createPersistStore( "[Chat History] ", toBeSummarizedMsgs, historyMsgLength, - modelConfig.compressMessageLengthThreshold, + config.compressMessageLengthThreshold, ); if ( - historyMsgLength > modelConfig.compressMessageLengthThreshold && - modelConfig.sendMemory + historyMsgLength > config.compressMessageLengthThreshold && + config.sendMemory ) { /** Destruct max_tokens while summarizing * this param is just shit @@ -601,7 +604,7 @@ export const useChatStore = createPersistStore( } log.debug("summarizeSession", messages); - webllm.chat({ + llm.chat({ messages: toBeSummarizedMsgs, config: { ...modelcfg, @@ -632,17 +635,21 @@ export const useChatStore = createPersistStore( if (session.messages.length === 0) { return; } - const lastMessage = session.messages[session.messages.length - 1]; + const messages = [...session.messages]; + const lastMessage = messages[messages.length - 1]; if ( lastMessage.role === "assistant" && lastMessage.streaming && lastMessage.content.length === 0 ) { // This message generation is interrupted by refresh and is stuck - session.messages.splice(session.messages.length - 1, 1); + messages.splice(session.messages.length - 1, 1); } // Reset streaming status for all messages - session.messages.forEach((m) => (m.streaming = false)); + session.messages = messages.map((m) => ({ + ...m, + streaming: false, + })); }); set(() => ({ sessions })); }, diff --git a/app/store/config.ts b/app/store/config.ts index bde2e797..417ccf5a 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -8,7 +8,7 @@ import { } from "../constant"; import { createPersistStore } from "../utils/store"; -export type ModelType = (typeof DEFAULT_MODELS)[number]["name"]; +export type Model = (typeof DEFAULT_MODELS)[number]["name"]; export enum SubmitKey { Enter = "Enter", @@ -29,50 +29,101 @@ export enum CacheType { IndexDB = "index_db", } -export const DEFAULT_CONFIG = { +export enum ModelClient { + WEBLLM = "webllm", + MLCLLM_API = "mlc-llm-api", +} + +export type ModelConfig = { + model: Model; + + // Chat configs + temperature: number; + top_p: number; + max_tokens: number; + presence_penalty: number; + frequency_penalty: number; + + // MLC LLM configs + mlc_endpoint: string; +}; + +export type ConfigType = { + lastUpdate: number; // timestamp, to merge state + + submitKey: SubmitKey; + avatar: string; + fontSize: number; + theme: Theme; + tightBorder: boolean; + sendPreviewBubble: boolean; + enableAutoGenerateTitle: boolean; + sidebarWidth: number; + + disablePromptHint: boolean; + hideBuiltinTemplates: boolean; + + sendMemory: boolean; + historyMessageCount: number; + compressMessageLengthThreshold: number; + enableInjectSystemPrompts: boolean; + template: string; + + modelClientType: ModelClient; + models: ModelRecord[]; + + cacheType: CacheType; + logLevel: LogLevel; + modelConfig: ModelConfig; +}; + +const DEFAULT_MODEL_CONFIG: ModelConfig = { + model: DEFAULT_MODELS[0].name, + + // Chat configs + temperature: 1.0, + top_p: 1, + max_tokens: 4000, + presence_penalty: 0, + frequency_penalty: 0, + + // Use recommended config to overwrite above parameters + ...DEFAULT_MODELS[0].recommended_config, + + mlc_endpoint: "", +}; + +export const DEFAULT_CONFIG: ConfigType = { lastUpdate: Date.now(), // timestamp, to merge state submitKey: SubmitKey.Enter, 
avatar: "1f603", fontSize: 14, - theme: Theme.Auto as Theme, + theme: Theme.Auto, tightBorder: false, sendPreviewBubble: true, enableAutoGenerateTitle: true, sidebarWidth: DEFAULT_SIDEBAR_WIDTH, disablePromptHint: false, - hideBuiltinTemplates: false, // dont add builtin masks - cacheType: "cache" as CacheType, - logLevel: "INFO" as LogLevel, - models: DEFAULT_MODELS as any as ModelRecord[], - - modelConfig: { - model: DEFAULT_MODELS[0].name, - sendMemory: true, - historyMessageCount: 4, - compressMessageLengthThreshold: 1000, - enableInjectSystemPrompts: false, - template: DEFAULT_INPUT_TEMPLATE, - - // Chat configs - temperature: 1.0, - top_p: 1, - max_tokens: 4000, - presence_penalty: 0, - frequency_penalty: 0, - - // Use recommended config to overwrite above parameters - ...DEFAULT_MODELS[0].recommended_config, - }, + sendMemory: true, + historyMessageCount: 4, + compressMessageLengthThreshold: 1000, + enableInjectSystemPrompts: false, + template: DEFAULT_INPUT_TEMPLATE, + + modelClientType: ModelClient.WEBLLM, + models: DEFAULT_MODELS, + cacheType: CacheType.Cache, + logLevel: "INFO", + + modelConfig: DEFAULT_MODEL_CONFIG, }; export type ChatConfig = typeof DEFAULT_CONFIG; -export type ModelConfig = ChatConfig["modelConfig"]; - export function limitNumber( x: number, min: number, @@ -88,7 +139,7 @@ export function limitNumber( export const ModalConfigValidator = { model(x: string) { - return x as ModelType; + return x as Model; }, max_tokens(x: number) { return limitNumber(x, 0, 512000, 1024); @@ -114,32 +165,7 @@ export const useAppConfig = createPersistStore( set(() => ({ ...DEFAULT_CONFIG })); }, - mergeModels(newModels: ModelRecord[]) { - if (!newModels || newModels.length === 0) { - return; - } - - const oldModels = get().models; - const modelMap: Record = {}; - - for (const model of oldModels) { - modelMap[model.name] = model; - } - - for (const model of newModels) { - modelMap[model.name] = model; - } - - set(() => ({ - models: Object.values(modelMap), - })); - }, - - allModels() { - return get().models; - }, - - selectModel(model: ModelType) { + selectModel(model: Model) { const config = DEFAULT_MODELS.find((m) => m.name === model); set((state) => ({ @@ -152,6 +178,17 @@ export const useAppConfig = createPersistStore( })); }, + setModels(models: ModelRecord[]) { + set((state) => ({ + ...state, + models, + modelConfig: { + ...state.modelConfig, + model: models[0].name, + }, + })); + }, + updateModelConfig(config: Partial) { set((state) => ({ ...state, @@ -164,20 +201,48 @@ export const useAppConfig = createPersistStore( }), { name: StoreKey.Config, - version: 0.41, + version: 0.43, migrate: (persistedState, version) => { if (version < 0.41) { return { ...DEFAULT_CONFIG, ...(persistedState as any), models: DEFAULT_MODELS as any as ModelRecord[], + + modelConfig: { + model: DEFAULT_MODELS[0].name, + + // Chat configs + temperature: 1.0, + top_p: 1, + max_tokens: 4000, + presence_penalty: 0, + frequency_penalty: 0, + + // Use recommended config to overwrite above parameters + ...DEFAULT_MODELS[0].recommended_config, + }, + }; + } + if (version < 0.42) { + return { + ...DEFAULT_CONFIG, + ...(persistedState as any), + models: DEFAULT_MODELS as any as ModelRecord[], + + sendMemory: (persistedState as any).modelConfig?.sendMemory || true, + historyMessageCount: + (persistedState as any).modelConfig?.historyMessageCount || 4, + compressMessageLengthThreshold: + (persistedState as any).modelConfig + ?.compressMessageLengthThreshold || 1000, + enableInjectSystemPrompts: + (persistedState 
as any).modelConfig?.enableInjectSystemPrompts || + false, + template: DEFAULT_INPUT_TEMPLATE, + modelConfig: { model: DEFAULT_MODELS[0].name, - sendMemory: true, - historyMessageCount: 4, - compressMessageLengthThreshold: 1000, - enableInjectSystemPrompts: false, - template: DEFAULT_INPUT_TEMPLATE, // Chat configs temperature: 1.0, diff --git a/app/styles/globals.scss b/app/styles/globals.scss index 619ceac4..75b31da4 100644 --- a/app/styles/globals.scss +++ b/app/styles/globals.scss @@ -250,7 +250,7 @@ div.math { overflow-x: auto; } -.modal-template { +.screen-model-container { z-index: 9999; position: fixed; top: 0; diff --git a/next.config.mjs b/next.config.mjs index f243d3b0..669d6fc4 100644 --- a/next.config.mjs +++ b/next.config.mjs @@ -10,7 +10,7 @@ const cspHeader = ` default-src 'self'; script-src 'self' 'unsafe-eval' 'unsafe-inline'; worker-src 'self'; - connect-src 'self' blob: data: https:; + connect-src 'self' blob: data: https: http:; style-src 'self' 'unsafe-inline'; img-src 'self' blob: data: https:; font-src 'self';