diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index 4fb5989e2..eccd80ffc 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -51,6 +51,7 @@ interface BaseChatProps { setUploadedFiles?: (files: File[]) => void; imageDataList?: string[]; setImageDataList?: (dataList: string[]) => void; + onRewind?: (messageId: string, contents: string) => void; } export const BaseChat = React.forwardRef( @@ -77,6 +78,7 @@ export const BaseChat = React.forwardRef( imageDataList = [], setImageDataList, messages, + onRewind, }, ref, ) => { @@ -241,6 +243,7 @@ export const BaseChat = React.forwardRef( className="flex flex-col w-full flex-1 max-w-chat pb-6 mx-auto z-1" messages={messages} isStreaming={isStreaming} + onRewind={onRewind} /> ) : null; }} diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 4861094f2..9eb9ffd91 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -49,7 +49,7 @@ export function resetChatFileWritten() { } async function flushSimulationData() { - console.log("FlushSimulationData"); + //console.log("FlushSimulationData"); const iframe = getCurrentIFrame(); if (!iframe) { @@ -60,7 +60,7 @@ async function flushSimulationData() { return; } - console.log("HaveSimulationData", simulationData.length); + //console.log("HaveSimulationData", simulationData.length); // Add the simulation data to the chat. await simulationAddData(simulationData); @@ -150,24 +150,6 @@ interface ChatProps { let gNumAborts = 0; -interface InjectedMessage { - message: Message; - previousId: string; -} - -function handleInjectMessages(baseMessages: Message[], injectedMessages: InjectedMessage[]) { - const messages = []; - for (const message of baseMessages) { - messages.push(message); - for (const injectedMessage of injectedMessages) { - if (injectedMessage.previousId === message.id) { - messages.push(injectedMessage.message); - } - } - } - return messages; -} - function filterFiles(files: FileMap): FileMap { const rv: FileMap = {}; for (const [path, file] of Object.entries(files)) { @@ -187,7 +169,6 @@ export const ChatImpl = memo( const [uploadedFiles, setUploadedFiles] = useState([]); // Move here const [imageDataList, setImageDataList] = useState([]); // Move here const [searchParams, setSearchParams] = useSearchParams(); - const [injectedMessages, setInjectedMessages] = useState([]); const [simulationLoading, setSimulationLoading] = useState(false); const files = useStore(workbenchStore.files); const { promptId } = useSettings(); @@ -196,7 +177,7 @@ export const ChatImpl = memo( const [animationScope, animate] = useAnimate(); - const { messages: baseMessages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({ + const { messages, isLoading, input, handleInputChange, setInput, stop, append, setMessages } = useChat({ api: '/api/chat', body: { files: filterFiles(files), @@ -213,10 +194,6 @@ export const ChatImpl = memo( initialInput: Cookies.get(PROMPT_COOKIE_KEY) || '', }); - const messages = useMemo(() => { - return handleInjectMessages(baseMessages, injectedMessages); - }, [baseMessages, injectedMessages]); - useEffect(() => { const prompt = searchParams.get('prompt'); @@ -384,7 +361,7 @@ export const ChatImpl = memo( } console.log("RecordingMessage", recordingMessage); - setInjectedMessages([...injectedMessages, { message: recordingMessage, previousId: messages[messages.length - 1].id }]); + setMessages([...messages, recordingMessage]); if (recordingId) { const info = await enhancedPromptPromise; @@ -396,7 +373,7 @@ export const ChatImpl = memo( simulationEnhancedPrompt = info.enhancedPrompt; console.log("EnhancedPromptMessage", info.enhancedPromptMessage); - setInjectedMessages([...injectedMessages, { message: info.enhancedPromptMessage, previousId: messages[messages.length - 1].id }]); + setMessages([...messages, info.enhancedPromptMessage]); } } finally { gLockSimulationData = false; @@ -451,6 +428,17 @@ export const ChatImpl = memo( saveProjectContents(lastMessage.id, { content: contentBase64 }); }; + const onRewind = async (messageId: string, contents: string) => { + console.log("Rewinding", messageId, contents); + + await workbenchStore.restoreProjectContentsBase64(messageId, contents); + + const messageIndex = messages.findIndex((message) => message.id === messageId); + if (messageIndex >= 0) { + setMessages(messages.slice(0, messageIndex + 1)); + } + }; + /** * Handles the change event for the textarea and updates the input state. * @param event - The change event from the textarea. @@ -517,6 +505,7 @@ export const ChatImpl = memo( setUploadedFiles={setUploadedFiles} imageDataList={imageDataList} setImageDataList={setImageDataList} + onRewind={onRewind} /> ); }, diff --git a/app/components/chat/LoadProblemButton.tsx b/app/components/chat/LoadProblemButton.tsx index 3f90e61be..ee50dbe74 100644 --- a/app/components/chat/LoadProblemButton.tsx +++ b/app/components/chat/LoadProblemButton.tsx @@ -1,12 +1,11 @@ import React, { useState } from 'react'; import type { Message } from 'ai'; import { toast } from 'react-toastify'; -import { createChatFromFolder, type FileArtifact } from '~/utils/folderImport'; +import { createChatFromFolder } from '~/utils/folderImport'; import { logStore } from '~/lib/stores/logs'; // Assuming logStore is imported from this location -import { assert, sendCommandDedicatedClient } from '~/lib/replay/ReplayProtocolClient'; +import { assert } from '~/lib/replay/ReplayProtocolClient'; import type { BoltProblem } from '~/lib/replay/Problems'; -import { getProblem } from '~/lib/replay/Problems'; -import JSZip from 'jszip'; +import { getProblem, extractFileArtifactsFromRepositoryContents } from '~/lib/replay/Problems'; interface LoadProblemButtonProps { className?: string; @@ -40,17 +39,7 @@ export async function loadProblem(problemId: string, importChat: (description: s const { repositoryContents, title: problemTitle } = problem; - const zip = new JSZip(); - await zip.loadAsync(repositoryContents, { base64: true }); - - const fileArtifacts: FileArtifact[] = []; - for (const [key, object] of Object.entries(zip.files)) { - if (object.dir) continue; - fileArtifacts.push({ - content: await object.async('text'), - path: key, - }); - } + const fileArtifacts = await extractFileArtifactsFromRepositoryContents(repositoryContents); try { const messages = await createChatFromFolder(fileArtifacts, [], "problem"); diff --git a/app/components/chat/Messages.client.tsx b/app/components/chat/Messages.client.tsx index b8dcfb554..b98fbf4ac 100644 --- a/app/components/chat/Messages.client.tsx +++ b/app/components/chat/Messages.client.tsx @@ -15,6 +15,7 @@ interface MessagesProps { className?: string; isStreaming?: boolean; messages?: Message[]; + onRewind?: (messageId: string, contents: string) => void; } interface ProjectContents { @@ -27,11 +28,8 @@ export function saveProjectContents(messageId: string, contents: ProjectContents gProjectContentsByMessageId.set(messageId, contents); } -// The rewind button is not fully implemented yet. -const EnableRewindButton = false; - export const Messages = React.forwardRef((props: MessagesProps, ref) => { - const { id, isStreaming = false, messages = [] } = props; + const { id, isStreaming = false, messages = [], onRewind } = props; const getLastMessageProjectContents = (index: number) => { // The message index is for the model response, and the project @@ -42,7 +40,11 @@ export const Messages = React.forwardRef((props: return undefined; } const previousMessage = messages[index - 2]; - return gProjectContentsByMessageId.get(previousMessage.id); + const contents = gProjectContentsByMessageId.get(previousMessage.id); + if (!contents) { + return undefined; + } + return { messageId: previousMessage.id, contents }; }; return ( @@ -82,13 +84,14 @@ export const Messages = React.forwardRef((props: )} - {!isUserMessage && messageId && getLastMessageProjectContents(index) && EnableRewindButton && ( + {!isUserMessage && messageId && onRewind && getLastMessageProjectContents(index) && (
- +