Skip to content

Commit

Permalink
Support rewinding to earlier messages (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhackett1024 authored Feb 13, 2025
1 parent af7b8c9 commit dd18d53
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 57 deletions.
3 changes: 3 additions & 0 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ interface BaseChatProps {
setUploadedFiles?: (files: File[]) => void;
imageDataList?: string[];
setImageDataList?: (dataList: string[]) => void;
onRewind?: (messageId: string, contents: string) => void;
}

export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
Expand All @@ -77,6 +78,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
imageDataList = [],
setImageDataList,
messages,
onRewind,
},
ref,
) => {
Expand Down Expand Up @@ -241,6 +243,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
className="flex flex-col w-full flex-1 max-w-chat pb-6 mx-auto z-1"
messages={messages}
isStreaming={isStreaming}
onRewind={onRewind}
/>
) : null;
}}
Expand Down
45 changes: 17 additions & 28 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export function resetChatFileWritten() {
}

async function flushSimulationData() {
console.log("FlushSimulationData");
//console.log("FlushSimulationData");

const iframe = getCurrentIFrame();
if (!iframe) {
Expand All @@ -60,7 +60,7 @@ async function flushSimulationData() {
return;
}

console.log("HaveSimulationData", simulationData.length);
//console.log("HaveSimulationData", simulationData.length);

// Add the simulation data to the chat.
await simulationAddData(simulationData);
Expand Down Expand Up @@ -150,24 +150,6 @@ interface ChatProps {

let gNumAborts = 0;

interface InjectedMessage {
message: Message;
previousId: string;
}

function handleInjectMessages(baseMessages: Message[], injectedMessages: InjectedMessage[]) {
const messages = [];
for (const message of baseMessages) {
messages.push(message);
for (const injectedMessage of injectedMessages) {
if (injectedMessage.previousId === message.id) {
messages.push(injectedMessage.message);
}
}
}
return messages;
}

function filterFiles(files: FileMap): FileMap {
const rv: FileMap = {};
for (const [path, file] of Object.entries(files)) {
Expand All @@ -187,7 +169,6 @@ export const ChatImpl = memo(
const [uploadedFiles, setUploadedFiles] = useState<File[]>([]); // Move here
const [imageDataList, setImageDataList] = useState<string[]>([]); // Move here
const [searchParams, setSearchParams] = useSearchParams();
const [injectedMessages, setInjectedMessages] = useState<InjectedMessage[]>([]);
const [simulationLoading, setSimulationLoading] = useState(false);
const files = useStore(workbenchStore.files);
const { promptId } = useSettings();
Expand All @@ -196,7 +177,7 @@ export const ChatImpl = memo(

const [animationScope, animate] = useAnimate();

const { messages: baseMessages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({
const { messages, isLoading, input, handleInputChange, setInput, stop, append, setMessages } = useChat({
api: '/api/chat',
body: {
files: filterFiles(files),
Expand All @@ -213,10 +194,6 @@ export const ChatImpl = memo(
initialInput: Cookies.get(PROMPT_COOKIE_KEY) || '',
});

const messages = useMemo(() => {
return handleInjectMessages(baseMessages, injectedMessages);
}, [baseMessages, injectedMessages]);

useEffect(() => {
const prompt = searchParams.get('prompt');

Expand Down Expand Up @@ -384,7 +361,7 @@ export const ChatImpl = memo(
}

console.log("RecordingMessage", recordingMessage);
setInjectedMessages([...injectedMessages, { message: recordingMessage, previousId: messages[messages.length - 1].id }]);
setMessages([...messages, recordingMessage]);

if (recordingId) {
const info = await enhancedPromptPromise;
Expand All @@ -396,7 +373,7 @@ export const ChatImpl = memo(
simulationEnhancedPrompt = info.enhancedPrompt;

console.log("EnhancedPromptMessage", info.enhancedPromptMessage);
setInjectedMessages([...injectedMessages, { message: info.enhancedPromptMessage, previousId: messages[messages.length - 1].id }]);
setMessages([...messages, info.enhancedPromptMessage]);
}
} finally {
gLockSimulationData = false;
Expand Down Expand Up @@ -451,6 +428,17 @@ export const ChatImpl = memo(
saveProjectContents(lastMessage.id, { content: contentBase64 });
};

const onRewind = async (messageId: string, contents: string) => {
console.log("Rewinding", messageId, contents);

await workbenchStore.restoreProjectContentsBase64(messageId, contents);

const messageIndex = messages.findIndex((message) => message.id === messageId);
if (messageIndex >= 0) {
setMessages(messages.slice(0, messageIndex + 1));
}
};

/**
* Handles the change event for the textarea and updates the input state.
* @param event - The change event from the textarea.
Expand Down Expand Up @@ -517,6 +505,7 @@ export const ChatImpl = memo(
setUploadedFiles={setUploadedFiles}
imageDataList={imageDataList}
setImageDataList={setImageDataList}
onRewind={onRewind}
/>
);
},
Expand Down
19 changes: 4 additions & 15 deletions app/components/chat/LoadProblemButton.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import React, { useState } from 'react';
import type { Message } from 'ai';
import { toast } from 'react-toastify';
import { createChatFromFolder, type FileArtifact } from '~/utils/folderImport';
import { createChatFromFolder } from '~/utils/folderImport';
import { logStore } from '~/lib/stores/logs'; // Assuming logStore is imported from this location
import { assert, sendCommandDedicatedClient } from '~/lib/replay/ReplayProtocolClient';
import { assert } from '~/lib/replay/ReplayProtocolClient';
import type { BoltProblem } from '~/lib/replay/Problems';
import { getProblem } from '~/lib/replay/Problems';
import JSZip from 'jszip';
import { getProblem, extractFileArtifactsFromRepositoryContents } from '~/lib/replay/Problems';

interface LoadProblemButtonProps {
className?: string;
Expand Down Expand Up @@ -40,17 +39,7 @@ export async function loadProblem(problemId: string, importChat: (description: s

const { repositoryContents, title: problemTitle } = problem;

const zip = new JSZip();
await zip.loadAsync(repositoryContents, { base64: true });

const fileArtifacts: FileArtifact[] = [];
for (const [key, object] of Object.entries(zip.files)) {
if (object.dir) continue;
fileArtifacts.push({
content: await object.async('text'),
path: key,
});
}
const fileArtifacts = await extractFileArtifactsFromRepositoryContents(repositoryContents);

try {
const messages = await createChatFromFolder(fileArtifacts, [], "problem");
Expand Down
21 changes: 12 additions & 9 deletions app/components/chat/Messages.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ interface MessagesProps {
className?: string;
isStreaming?: boolean;
messages?: Message[];
onRewind?: (messageId: string, contents: string) => void;
}

interface ProjectContents {
Expand All @@ -27,11 +28,8 @@ export function saveProjectContents(messageId: string, contents: ProjectContents
gProjectContentsByMessageId.set(messageId, contents);
}

// The rewind button is not fully implemented yet.
const EnableRewindButton = false;

export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props: MessagesProps, ref) => {
const { id, isStreaming = false, messages = [] } = props;
const { id, isStreaming = false, messages = [], onRewind } = props;

const getLastMessageProjectContents = (index: number) => {
// The message index is for the model response, and the project
Expand All @@ -42,7 +40,11 @@ export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props:
return undefined;
}
const previousMessage = messages[index - 2];
return gProjectContentsByMessageId.get(previousMessage.id);
const contents = gProjectContentsByMessageId.get(previousMessage.id);
if (!contents) {
return undefined;
}
return { messageId: previousMessage.id, contents };
};

return (
Expand Down Expand Up @@ -82,13 +84,14 @@ export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props:
<AssistantMessage content={content} annotations={message.annotations} />
)}
</div>
{!isUserMessage && messageId && getLastMessageProjectContents(index) && EnableRewindButton && (
{!isUserMessage && messageId && onRewind && getLastMessageProjectContents(index) && (
<div className="flex gap-2 flex-col lg:flex-row">
<WithTooltip tooltip="Rewind to this message">
<WithTooltip tooltip="Undo changes in this message">
<button
onClick={() => {
const contents = getLastMessageProjectContents(index);
assert(contents);
const info = getLastMessageProjectContents(index);
assert(info);
onRewind(info.messageId, info.contents.content);
}}
key="i-ph:arrow-u-up-left"
className={classNames(
Expand Down
17 changes: 17 additions & 0 deletions app/lib/replay/Problems.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { toast } from "react-toastify";
import { sendCommandDedicatedClient } from "./ReplayProtocolClient";
import type { ProtocolMessage } from "./SimulationPrompt";
import Cookies from 'js-cookie';
import JSZip from 'jszip';
import type { FileArtifact } from "~/utils/folderImport";

export interface BoltProblemComment {
username?: string;
Expand Down Expand Up @@ -149,3 +151,18 @@ export function getProblemsUsername(): string | undefined {
export function setProblemsUsername(username: string) {
Cookies.set(nutProblemsUsernameCookieName, username);
}

export async function extractFileArtifactsFromRepositoryContents(repositoryContents: string): Promise<FileArtifact[]> {
const zip = new JSZip();
await zip.loadAsync(repositoryContents, { base64: true });

const fileArtifacts: FileArtifact[] = [];
for (const [key, object] of Object.entries(zip.files)) {
if (object.dir) continue;
fileArtifacts.push({
content: await object.async('text'),
path: key,
});
}
return fileArtifacts;
}
2 changes: 1 addition & 1 deletion app/lib/replay/Recording.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function addRecordingMessageHandler(messageHandlerId: string) {
}

async function getSimulationData(): Promise<SimulationData> {
console.log("GetSimulationData", simulationData.length, numSimulationPacketsSent);
//console.log("GetSimulationData", simulationData.length, numSimulationPacketsSent);
const data = simulationData.slice(numSimulationPacketsSent);
numSimulationPacketsSent = simulationData.length;
return data;
Expand Down
2 changes: 1 addition & 1 deletion app/lib/replay/ReplayProtocolClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function stringToBase64(inputString: string) {
}

function logDebug(msg: string, tags: Record<string, any> = {}) {
console.log(msg, JSON.stringify(tags));
//console.log(msg, JSON.stringify(tags));
}

class ProtocolError extends Error {
Expand Down
65 changes: 62 additions & 3 deletions app/lib/stores/workbench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Cookies from 'js-cookie';
import { createSampler } from '~/utils/sampler';
import { uint8ArrayToBase64 } from '../replay/ReplayProtocolClient';
import type { ActionAlert } from '~/types/actions';
import { extractFileArtifactsFromRepositoryContents } from '../replay/Problems';

export interface ArtifactState {
id: string;
Expand Down Expand Up @@ -326,10 +327,12 @@ export class WorkbenchStore {
unreachable('Artifact not found');
}

const action = artifact.runner.actions.get()[data.actionId];
if (data.actionId != 'restore-contents-action-id') {
const action = artifact.runner.actions.get()[data.actionId];

if (!action || action.executed) {
return;
if (!action || action.executed) {
return;
}
}

if (data.action.type === 'file') {
Expand Down Expand Up @@ -421,6 +424,62 @@ export class WorkbenchStore {
return { contentBase64, uniqueProjectName };
}

async restoreProjectContentsBase64(messageId: string, contentBase64: string) {
const fileArtifacts = await extractFileArtifactsFromRepositoryContents(contentBase64);

const modifiedFilePaths = new Set<string>();

// Check if any files we know about have different contents in the artifacts.
const files = this.files.get();
const fileRelativePaths = new Set<string>();
for (const [filePath, dirent] of Object.entries(files)) {
if (dirent?.type === 'file' && !dirent.isBinary) {
const relativePath = extractRelativePath(filePath);
fileRelativePaths.add(relativePath);

const content = dirent.content;

const artifact = fileArtifacts.find((artifact) => artifact.path === relativePath);
const artifactContent = artifact?.content ?? "";

if (content != artifactContent) {
modifiedFilePaths.add(relativePath);
}
}
}

// Also create any new files in the artifacts.
for (const artifact of fileArtifacts) {
if (!fileRelativePaths.has(artifact.path)) {
modifiedFilePaths.add(artifact.path);
}
}

const actionArtifactId = `restore-contents-artifact-id-${messageId}`;

for (const filePath of modifiedFilePaths) {
console.log("RestoreModifiedFile", filePath);

const artifact = fileArtifacts.find((artifact) => artifact.path === filePath);
const artifactContent = artifact?.content ?? "";

const actionId = `restore-contents-action-${messageId}-${filePath}`;
const data: ActionCallbackData = {
actionId,
messageId,
artifactId: actionArtifactId,
action: {
type: 'file',
filePath: filePath,
content: artifactContent,
},
};

this.addAction(data);
this.runAction(data);
}
}

async syncFiles(targetHandle: FileSystemDirectoryHandle) {
const files = this.files.get();
const syncedFiles = [];
Expand Down

0 comments on commit dd18d53

Please sign in to comment.