feat: connect to mlc-llm REST API endpoint
Neet-Nestor committed Jun 23, 2024
1 parent 2ac0357 commit 2fb025c
Showing 22 changed files with 694 additions and 589 deletions.
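
In brief: a new `MlcLLMApi` client (app/client/mlcllm.ts) talks to an mlc-llm server over its OpenAI-compatible REST API; the shared `LLMApi` interface gains a `models()` method and drops the `onController` callback; the `WebLLMContext` and new `MLCLLMContext` providers move into `app/context`; a `useModels` hook fills the model list from either `DEFAULT_MODELS` or the server's `/v1/models` endpoint, depending on the selected `ModelClient`; and the store's `ModelType` type is renamed to `Model` throughout.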
8 changes: 4 additions & 4 deletions app/client/api.ts
@@ -1,10 +1,10 @@
import { ChatCompletionFinishReason, CompletionUsage } from "@mlc-ai/web-llm";
import { CacheType, ModelType } from "../store";
import { CacheType, Model } from "../store";
export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];

export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;
export type ChatModel = Model;

export interface MultimodalContent {
type: "text" | "image_url";
@@ -40,7 +40,6 @@ export interface ChatOptions {
usage?: CompletionUsage,
) => void;
onError?: (err: Error) => void;
onController?: (controller: AbortController) => void;
}

export interface LLMUsage {
@@ -60,7 +59,7 @@ export interface ModelRecord {
buffer_size_required_bytes?: number;
low_resource_required?: boolean;
required_features?: string[];
recommended_config: {
recommended_config?: {
temperature?: number;
top_p?: number;
presence_penalty?: number;
@@ -71,4 +70,5 @@
export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract abort(): Promise<void>;
abstract models(): Promise<ModelRecord[] | Model[]>;
}
100 changes: 100 additions & 0 deletions app/client/mlcllm.ts
@@ -0,0 +1,100 @@
import log from "loglevel";
import { ChatOptions, LLMApi } from "./api";
import { ChatCompletionFinishReason } from "@mlc-ai/web-llm";

export class MlcLLMApi implements LLMApi {
private endpoint: string;
private abortController: AbortController | null = null;

constructor(endpoint: string) {
this.endpoint = endpoint;
}

async chat(options: ChatOptions) {
const { messages, config } = options;

const payload = {
messages: messages,
...config,
};

// Instantiate a new AbortController for this request
this.abortController = new AbortController();
const { signal } = this.abortController;

let reply: string = "";
let stopReason: ChatCompletionFinishReason | undefined;

try {
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
signal,
});

if (config.stream) {
const reader = response.body!.getReader();
while (true) {
const { value, done } = await reader.read();
if (done) break;
// Extract the SSE "data:" events from this network chunk
const chunk = new TextDecoder("utf-8").decode(value);
// A chunk may carry several "data:" events; handle each of them
for (const match of chunk.matchAll(/data: (.+)/g)) {
// The "[DONE]" sentinel is plain text, not JSON, so don't parse it
if (match[1].trim() === "[DONE]") continue;
const data = JSON.parse(match[1]);
if (data.choices && data.choices.length > 0) {
// Role-only or finish-only deltas may carry no content
reply += data.choices[0].delta.content ?? ""; // Append the content
options.onUpdate?.(reply, chunk); // Handle the chunk update

if (data.choices[0].finish_reason) {
stopReason = data.choices[0].finish_reason;
}
}
}

// Ending the stream once the "[DONE]" sentinel has arrived
if (chunk.includes("[DONE]")) {
break;
}
}
options.onFinish(reply, stopReason);
} else {
const data = await response.json();
options.onFinish(
data.choices[0].message.content,
data.choices[0].finish_reason,
);
}
} catch (error: any) {
if (error.name === "AbortError") {
log.info("MLC_LLM: chat aborted");
} else {
log.error("MLC_LLM: Fetch error:", error);
options.onError?.(error);
}
}
}

// Implements the abort method to cancel the request
async abort() {
this.abortController?.abort();
}

async models() {
try {
const response = await fetch(`${this.endpoint}/v1/models`, {
method: "GET",
});
const jsonRes = await response.json();
return jsonRes.data.map((model: { id: string }) => ({
name: model.id,
display_name: model.id.split("/")[model.id.split("/").length - 1],
}));
} catch (error: any) {
log.error("MLC_LLM: Fetch error:", error);
}
}
}
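
For orientation, a hedged usage sketch of the new client follows. The endpoint URL, message shape, and config fields are illustrative assumptions; in the app the endpoint comes from `config.modelConfig.mlc_endpoint` and `chat()` is driven by the chat store (see app/components/home.tsx and app/components/chat.tsx below).

// Usage sketch — endpoint URL and option values are assumptions
const api = new MlcLLMApi("http://127.0.0.1:8000");

await api.chat({
  messages: [{ role: "user", content: "Hello!" }],
  config: { stream: true },
  // Receives the accumulated reply plus the latest raw chunk
  onUpdate: (reply) => console.log("partial:", reply),
  // Receives the final reply and, when present, the finish reason
  onFinish: (reply, stopReason) => console.log("done:", reply, stopReason),
  onError: (err) => console.error("chat failed:", err),
});

// Abort an in-flight request; the AbortError branch above logs it
await api.abort();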
7 changes: 5 additions & 2 deletions app/client/webllm.ts
@@ -17,6 +17,7 @@ import {
import { ChatOptions, LLMApi, LLMConfig, RequestMessage } from "./api";
import { LogLevel } from "@mlc-ai/web-llm";
import { fixMessage } from "../utils";
import { DEFAULT_MODELS } from "../constant";

const KEEP_ALIVE_INTERVAL = 5_000;

@@ -215,6 +216,8 @@ export class WebLLMApi implements LLMApi {
usage: chatCompletion.usage,
};
}
}

export const WebLLMContext = createContext<WebLLMApi | undefined>(undefined);
async models() {
return DEFAULT_MODELS;
}
}
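
The `WebLLMContext` export removed above moves, together with a new `MLCLLMContext`, into a shared module that both the Home providers and the chat component import. A minimal sketch of what `app/context.ts` plausibly contains (the file is among the 22 changed files but is not shown in this excerpt):

// app/context.ts — sketch only; the actual file is not shown in this diff
import { createContext } from "react";
import { WebLLMApi } from "./client/webllm";
import { MlcLLMApi } from "./client/mlcllm";

export const WebLLMContext = createContext<WebLLMApi | undefined>(undefined);
export const MLCLLMContext = createContext<MlcLLMApi | undefined>(undefined);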
36 changes: 24 additions & 12 deletions app/components/chat.tsx
@@ -45,7 +45,8 @@ import {
createMessage,
useAppConfig,
DEFAULT_TOPIC,
ModelType,
Model,
ModelClient,
} from "../store";

import {
@@ -92,9 +93,9 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command";
import { prettyObject } from "../utils/format";
import { ExportMessageModal } from "./exporter";
import { MultimodalContent } from "../client/api";
import { WebLLMContext } from "../client/webllm";
import { Template, useTemplateStore } from "../store/template";
import Image from "next/image";
import { MLCLLMContext, WebLLMContext } from "../context";

export function ScrollDownToast(prop: { show: boolean; onclick: () => void }) {
return (
@@ -125,7 +126,7 @@ export function SessionConfigModel(props: { onClose: () => void }) {
};

return (
<div className="modal-template">
<div className="screen-model-container">
<Modal
title={Locale.Context.Edit}
onClose={() => props.onClose()}
@@ -556,7 +557,7 @@ export function ChatActions(props: {
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
config.selectModel(s[0] as ModelType);
config.selectModel(s[0] as Model);
showToast(s[0]);
}}
/>
@@ -606,6 +607,12 @@ function _Chat() {
const [uploading, setUploading] = useState(false);
const [showEditPromptModal, setShowEditPromptModal] = useState(false);
const webllm = useContext(WebLLMContext)!;
const mlcllm = useContext(MLCLLMContext)!;

const llm =
config.modelClientType === ModelClient.MLCLLM_API ? mlcllm : webllm;

const models = config.models;

// prompt hints
const promptStore = usePromptStore();
@@ -685,7 +692,7 @@

if (isStreaming) return;

chatStore.onUserInput(userInput, webllm, attachImages);
chatStore.onUserInput(userInput, llm, attachImages);
setAttachImages([]);
localStorage.setItem(LAST_INPUT_KEY, userInput);
setUserInput("");
@@ -713,7 +720,7 @@

// stop response
const onUserStop = () => {
webllm.abort();
llm.abort();
chatStore.stopStreaming();
};

@@ -836,7 +843,7 @@
// resend the message
const textContent = getMessageTextContent(userMessage);
const images = getMessageImages(userMessage);
chatStore.onUserInput(textContent, webllm, images);
chatStore.onUserInput(textContent, llm, images);
inputRef.current?.focus();
};

@@ -867,7 +874,13 @@
]
: [],
);
}, [config.sendPreviewBubble, context, session.messages, userInput]);
}, [
config.sendPreviewBubble,
context,
session.messages,
session.messages.length,
userInput,
]);

const [msgRenderIndex, _setMsgRenderIndex] = useState(
Math.max(0, renderMessages.length - CHAT_PAGE_SIZE),
@@ -1183,10 +1196,9 @@
)}
{message.role === "assistant" && (
<div className={styles["chat-message-role-name"]}>
{config.models.find((m) => m.name === message.model)
? config.models.find(
(m) => m.name === message.model,
)!.display_name
{models.find((m) => m.name === message.model)
? models.find((m) => m.name === message.model)!
.display_name
: message.model}
</div>
)}
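Net effect in `_Chat`: the active client is resolved once from `config.modelClientType` and threaded through sending (`onUserInput`), stopping (`llm.abort()`), and resending, so the chat store and the message renderer stay agnostic about whether replies come from in-browser WebLLM or a remote mlc-llm server.
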
4 changes: 2 additions & 2 deletions app/components/emoji.tsx
@@ -4,7 +4,7 @@ import EmojiPicker, {
Theme as EmojiTheme,
} from "emoji-picker-react";

import { ModelType } from "../store";
import { Model } from "../store";

import MlcIcon from "../icons/mlc.svg";

@@ -31,7 +31,7 @@
);
}

export function Avatar(props: { model?: ModelType; avatar?: string }) {
export function Avatar(props: { model?: Model; avatar?: string }) {
if (props.model) {
return (
<div className="bot-avatar mlc-icon no-dark">
4 changes: 2 additions & 2 deletions app/components/exporter.tsx
@@ -1,5 +1,5 @@
/* eslint-disable @next/next/no-img-element */
import { ChatMessage, ModelType, useAppConfig, useChatStore } from "../store";
import { ChatMessage, Model, useAppConfig, useChatStore } from "../store";
import Locale from "../locales";
import styles from "./exporter.module.scss";
import {
@@ -46,7 +46,7 @@ const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {

export function ExportMessageModal(props: { onClose: () => void }) {
return (
<div className="modal-template">
<div className="screen-model-container">
<Modal
title={Locale.Export.Title}
onClose={props.onClose}
43 changes: 38 additions & 5 deletions app/components/home.tsx
@@ -6,7 +6,7 @@ import styles from "./home.module.scss";

import log from "loglevel";
import dynamic from "next/dynamic";
import { useState, useEffect, useRef } from "react";
import { useState, useEffect, useRef, useMemo, useCallback } from "react";
import {
HashRouter as Router,
Routes,
@@ -20,13 +20,15 @@ import LoadingIcon from "../icons/three-dots.svg";

import Locale from "../locales";
import { getCSSVar, useMobileScreen } from "../utils";
import { Path, SlotID } from "../constant";
import { DEFAULT_MODELS, Path, SlotID } from "../constant";
import { ErrorBoundary } from "./error";
import { getISOLang, getLang } from "../locales";
import { SideBar } from "./sidebar";
import { useAppConfig } from "../store/config";
import { WebLLMApi, WebLLMContext } from "../client/webllm";
import { useChatStore } from "../store";
import { WebLLMApi } from "../client/webllm";
import { ModelClient, useChatStore } from "../store";
import { MLCLLMContext, WebLLMContext } from "../context";
import { MlcLLMApi } from "../client/mlcllm";

export function Loading(props: { noLogo?: boolean }) {
return (
@@ -251,6 +253,17 @@ const useWebLLM = () => {
return { webllm, isWebllmActive };
};

const useMlcLLM = () => {
const config = useAppConfig();
const [mlcllm, setMlcLlm] = useState<MlcLLMApi | undefined>(undefined);

useEffect(() => {
setMlcLlm(new MlcLLMApi(config.modelConfig.mlc_endpoint));
}, [config.modelConfig.mlc_endpoint, setMlcLlm]);

return mlcllm;
};

const useLoadUrlParam = () => {
const config = useAppConfig();

@@ -306,14 +319,32 @@
}, [config.logLevel, webllm?.webllm?.engine]);
};

const useModels = (mlcllm: MlcLLMApi | undefined) => {
const config = useAppConfig();

useEffect(() => {
if (config.modelClientType == ModelClient.WEBLLM) {
config.setModels(DEFAULT_MODELS);
} else if (config.modelClientType == ModelClient.MLCLLM_API) {
if (mlcllm) {
mlcllm.models().then((models) => {
config.setModels(models);
});
}
}
}, [config.modelClientType, mlcllm]);
};

export function Home() {
const hasHydrated = useHasHydrated();
const { webllm, isWebllmActive } = useWebLLM();
const mlcllm = useMlcLLM();

useSwitchTheme();
useHtmlLang();
useLoadUrlParam();
useStopStreamingMessages();
useModels(mlcllm);
useLogLevel(webllm);

if (!hasHydrated || !webllm || !isWebllmActive) {
@@ -328,7 +359,9 @@
<ErrorBoundary>
<Router>
<WebLLMContext.Provider value={webllm}>
<Screen />
<MLCLLMContext.Provider value={mlcllm}>
<Screen />
</MLCLLMContext.Provider>
</WebLLMContext.Provider>
</Router>
</ErrorBoundary>
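
The `useModels` hook keys off `config.modelClientType` to decide between the bundled `DEFAULT_MODELS` and the list served by `/v1/models`. A sketch of the discriminator it assumes (the enum lives in `app/store`, not shown in this excerpt; the member values are hypothetical):

// Assumed shape of ModelClient in app/store — member values are hypothetical
export enum ModelClient {
  WEBLLM = "webllm",
  MLCLLM_API = "mlcllm_api",
}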