Skip to content

Commit

Permalink
feat: Provide options for cache types
Browse files Browse the repository at this point in the history
  • Loading branch information
Neet-Nestor committed May 17, 2024
1 parent 3657e0a commit 01b6716
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 40 deletions.
4 changes: 3 additions & 1 deletion app/client/api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getClientConfig } from "../config/client";
import { ACCESS_CODE_PREFIX } from "../constant";
import { ModelType, useChatStore } from "../store";
import { CacheType, ModelType, useChatStore } from "../store";
export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];

Expand All @@ -22,6 +22,7 @@ export interface RequestMessage {

export interface LLMConfig {
model: string;
cache: CacheType;
temperature?: number;
top_p?: number;
stream?: boolean;
Expand Down Expand Up @@ -60,4 +61,5 @@ export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
abstract models(): Promise<LLMModel[]>;
abstract clear(): void;
}
7 changes: 6 additions & 1 deletion app/client/webllm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ import {

import { ChatOptions, LLMApi, LLMConfig } from "./api";
import { ChatCompletionMessageParam } from "@mlc-ai/web-llm";
import { useAppConfig } from "../store";

export class WebLLMApi implements LLMApi {
private currentModel?: string;
private engine?: EngineInterface;

  /**
   * Drop the cached engine instance. Called when the cache-type setting
   * changes (see settings.tsx) so the next request re-initializes the
   * engine — presumably via initModel — with the new cache configuration.
   */
  clear() {
    this.engine = undefined;
  }

async initModel(
config: LLMConfig,
onUpdate?: (message: string, chunk: string) => void,
Expand All @@ -31,7 +36,7 @@ export class WebLLMApi implements LLMApi {
},
appConfig: {
...prebuiltAppConfig,
useIndexedDBCache: true,
useIndexedDBCache: config.cache === "index_db",
},
initProgressCallback: (report: InitProgressReport) => {
onUpdate?.(report.text, report.text);
Expand Down
28 changes: 3 additions & 25 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,13 @@ function _Chat() {
const config = useAppConfig();
const fontSize = config.fontSize;

const currentModel = chatStore.currentSession().mask.modelConfig.model;
const isGenerating = session.isGenerating;

const [showExport, setShowExport] = useState(false);

const inputRef = useRef<HTMLTextAreaElement>(null);
const [userInput, setUserInput] = useState("");
const [isLoading, setIsLoading] = useState(false);
const [isGenerating, setIsGenerating] = useState(false);
const { submitKey, shouldSubmit } = useSubmitHandler();
const scrollRef = useRef<HTMLDivElement>(null);
const isScrolledToBottom = scrollRef?.current
Expand Down Expand Up @@ -771,18 +770,8 @@ function _Chat() {

if (isGenerating) return;
setIsLoading(true);
setIsGenerating(true);
chatStore
.onUserInput(
userInput,
attachImages,
() => {
setIsGenerating(true);
},
() => {
setIsGenerating(false);
},
)
.onUserInput(userInput, attachImages)
.then(() => setIsLoading(false));
setAttachImages([]);
localStorage.setItem(LAST_INPUT_KEY, userInput);
Expand Down Expand Up @@ -934,18 +923,7 @@ function _Chat() {
setIsLoading(true);
const textContent = getMessageTextContent(userMessage);
const images = getMessageImages(userMessage);
chatStore
.onUserInput(
textContent,
images,
() => {
setIsGenerating(true);
},
() => {
setIsGenerating(false);
},
)
.then(() => setIsLoading(false));
chatStore.onUserInput(textContent, images).then(() => setIsLoading(false));
inputRef.current?.focus();
};

Expand Down
32 changes: 31 additions & 1 deletion app/components/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ import {
import { ModelConfigList } from "./model-config";

import { IconButton } from "./button";
import { SubmitKey, useChatStore, Theme, useAppConfig } from "../store";
import {
SubmitKey,
useChatStore,
Theme,
useAppConfig,
CacheType,
} from "../store";

import Locale, {
AllLangs,
Expand All @@ -37,6 +43,7 @@ import { InputRange } from "./input-range";
import { useNavigate } from "react-router-dom";
import { Avatar, AvatarPicker } from "./emoji";
import { nanoid } from "nanoid";
import { webllm } from "../client/webllm";

function EditPromptModal(props: { id: string; onClose: () => void }) {
const promptStore = usePromptStore();
Expand Down Expand Up @@ -456,6 +463,29 @@ export function Settings() {
</List>

<List id={SlotID.CustomModel}>
<ListItem
title={Locale.Settings.Access.CacheType.Title}
subTitle={Locale.Settings.Access.CacheType.SubTitle}
>
<Select
value="cache"
onChange={(e) => {
webllm.clear();
updateConfig(
(config) =>
(config.cacheType = e.currentTarget
.value as any as CacheType),
);
}}
>
<option value="cache" key="cache">
Cache
</option>
<option value="index_db" key="index_db">
Index DB
</option>
</Select>
</ListItem>
<ListItem
title={Locale.Settings.Access.CustomModel.Title}
subTitle={Locale.Settings.Access.CustomModel.SubTitle}
Expand Down
4 changes: 4 additions & 0 deletions app/locales/cn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ const cn = {
SubTitle: "选择一个特定的 API 版本",
},
},
CacheType: {
Title: "缓存类型",
SubTitle: "使用IndexDB或Cache API作为模型缓存",
},
CustomModel: {
Title: "自定义模型名",
SubTitle: "增加自定义模型可选项,使用英文逗号隔开",
Expand Down
4 changes: 4 additions & 0 deletions app/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ const en: LocaleType = {
SubTitle: "Select and input a specific API version",
},
},
CacheType: {
Title: "Cache Type",
SubTitle: "Use IndexDB or Cache API to store model weights",
},
CustomModel: {
Title: "Custom Models",
SubTitle: "Custom model options, seperated by comma",
Expand Down
26 changes: 15 additions & 11 deletions app/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export interface ChatSession {
lastUpdate: number;
lastSummarizeIndex: number;
clearContextIndex?: number;
isGenerating: boolean;

mask: Mask;
}
Expand All @@ -76,6 +77,7 @@ function createEmptySession(): ChatSession {
},
lastUpdate: Date.now(),
lastSummarizeIndex: 0,
isGenerating: false,

mask: createEmptyMask(),
};
Expand Down Expand Up @@ -282,12 +284,7 @@ export const useChatStore = createPersistStore(
get().summarizeSession();
},

async onUserInput(
content: string,
attachImages?: string[],
onGenerateStart?: () => void,
onGenerateFinish?: () => void,
) {
async onUserInput(content: string, attachImages?: string[]) {
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;

Expand Down Expand Up @@ -340,14 +337,17 @@ export const useChatStore = createPersistStore(
savedUserMessage,
botMessage,
]);
session.isGenerating = true;
});

onGenerateStart?.();

// make request
webllm.chat({
messages: sendMessages,
config: { ...modelConfig, stream: true },
config: {
...modelConfig,
cache: useAppConfig.getState().cacheType,
stream: true,
},
onUpdate(message) {
botMessage.streaming = true;
if (message) {
Expand All @@ -363,8 +363,10 @@ export const useChatStore = createPersistStore(
botMessage.content = message;
get().onNewMessage(botMessage);
}
get().updateCurrentSession((session) => {
session.isGenerating = false;
});
ChatControllerPool.remove(session.id, botMessage.id);
onGenerateFinish?.();
},
onError(error) {
const isAborted = error.message.includes("aborted");
Expand All @@ -379,13 +381,13 @@ export const useChatStore = createPersistStore(
botMessage.isError = !isAborted;
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
session.isGenerating = false;
});
ChatControllerPool.remove(
session.id,
botMessage.id ?? messageIndex,
);

onGenerateFinish?.();
console.error("[Chat] failed ", error);
},
onController(controller) {
Expand Down Expand Up @@ -543,6 +545,7 @@ export const useChatStore = createPersistStore(
messages: topicMessages,
config: {
model: session.mask.modelConfig.model,
cache: useAppConfig.getState().cacheType,
stream: false,
},
onFinish(message) {
Expand Down Expand Up @@ -603,6 +606,7 @@ export const useChatStore = createPersistStore(
...modelcfg,
stream: true,
model: session.mask.modelConfig.model,
cache: useAppConfig.getState().cacheType,
},
onUpdate(message) {
session.memoryPrompt = message;
Expand Down
8 changes: 7 additions & 1 deletion app/store/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ export enum Theme {
Light = "light",
}

/**
 * Storage backend for downloaded model weights in the browser.
 * Consumed by WebLLMApi, which sets `useIndexedDBCache` to true only
 * when the value is `IndexDB` (i.e. "index_db").
 */
export enum CacheType {
  Cache = "cache", // browser Cache API (default)
  IndexDB = "index_db", // IndexedDB storage
}

export const DEFAULT_CONFIG = {
lastUpdate: Date.now(), // timestamp, to merge state

Expand All @@ -41,11 +46,12 @@ export const DEFAULT_CONFIG = {

hideBuiltinMasks: false, // dont add builtin masks

cacheType: "cache" as CacheType,
customModels: "",
models: DEFAULT_MODELS as any as LLMModel[],

modelConfig: {
model: "gpt-3.5-turbo" as ModelType,
model: "Llama-3-8B-Instruct-q4f32_1-1k" as ModelType,
temperature: 0.5,
top_p: 1,
max_tokens: 4000,
Expand Down

0 comments on commit 01b6716

Please sign in to comment.