feat: generate text using vercel ai sdk
laginha committed Sep 27, 2024
1 parent 7b8accb commit b30e39b
Showing 7 changed files with 1,743 additions and 200 deletions.
2 changes: 2 additions & 0 deletions .env.example
@@ -1,4 +1,6 @@
 VITE_OPENAI_API_BASE=""
 VITE_OPENAI_GET_KEY=""
+VITE_OPENAI_API_BASE=""
+VITE_GEMINI_API_BASE=""
 VITE_ANTHROPIC_API_MODELS='claude-3-5-sonnet-20240620,claude-3-haiku-20240307'
 VITE_GOOGLE_API_MODELS='gemini-1.5-flash,gemini-1.5-pro'
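At runtime these model lists arrive as plain comma-separated strings on import.meta.env, so the app presumably splits them into arrays before matching a chosen model to a provider. A minimal sketch of that parsing, assuming Vite's env handling (the variable names below are hypothetical, not from this commit):

// Hypothetical sketch: turn the comma-separated env values into model-id arrays.
const anthropicModels: string[] =
  (import.meta.env.VITE_ANTHROPIC_API_MODELS ?? "").split(",").filter(Boolean);
const googleModels: string[] =
  (import.meta.env.VITE_GOOGLE_API_MODELS ?? "").split(",").filter(Boolean);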
1,736 changes: 1,646 additions & 90 deletions package-lock.json


5 changes: 4 additions & 1 deletion package.json
@@ -9,14 +9,17 @@
"preview": "vite preview"
},
"dependencies": {
"@ai-sdk/anthropic": "^0.0.50",
"@ai-sdk/google": "^0.0.50",
"@ai-sdk/openai": "^0.0.62",
"@chakra-ui/icons": "^2.0.17",
"@chakra-ui/react": "^2.5.1",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"ai": "^3.4.4",
"framer-motion": "^9.0.4",
"highlightjs-solidity": "^2.0.6",
"mixpanel-browser": "^2.46.0",
"openai": "^3.3.0",
"openai-streams": "^6.2.0",
"re-resizable": "^6.9.9",
"react": "^18.2.0",
151 changes: 45 additions & 106 deletions src/components/App.tsx
@@ -67,7 +67,7 @@ import { NavigationBar } from "./utils/NavigationBar";
 import { CheckCircleIcon } from "@chakra-ui/icons";
 import { Box, useDisclosure, Spinner, useToast } from "@chakra-ui/react";
 import mixpanel from "mixpanel-browser";
-import { CreateCompletionResponseChoicesInner } from "openai";
+// import { CreateCompletionResponseChoicesInner } from "openai";
 import { OpenAI } from "openai-streams";
 import { Resizable } from "re-resizable";
 import { useEffect, useState, useCallback, useRef } from "react";
@@ -89,6 +89,7 @@ import ReactFlow, {
} from "reactflow";
import "reactflow/dist/style.css";
import { yieldStream } from "yield-stream";
import generateText from "../utils/generateText";

function App() {
const toast = useToast();
@@ -365,116 +366,48 @@ function App() {
     if (firstCompletionId === undefined) throw new Error("No first completion id!");
 
     (async () => {
-      const stream = await OpenAI(
-        "chat",
-        {
-          model,
-          n: responses,
-          temperature: temp,
-          messages: messagesFromLineage(parentNodeLineage, settings),
-        },
-        { apiKey: apiKey!, mode: "raw", apiBase: apiBase }
+      const provider = Object.keys(availableModels || {}).find((key) =>
+        // @ts-ignore
+        availableModels[key].includes(model)
       );
+      if (!provider) return;
 
-      const DECODER = new TextDecoder();
-
-      const abortController = new AbortController();
-
-      for await (const chunk of yieldStream(stream, abortController)) {
-        if (abortController.signal.aborted) break;
-
-        try {
-          const decoded = JSON.parse(DECODER.decode(chunk));
-
-          if (decoded.choices === undefined)
-            throw new Error(
-              "No choices in response. Decoded response: " + JSON.stringify(decoded)
-            );
+      const key = allApiKeys[provider as ApiKeyProvider];
+      if (!key) return;
 
-          const choice: CreateChatCompletionStreamResponseChoicesInner =
-            decoded.choices[0];
+      const messages = messagesFromLineage(parentNodeLineage, settings);
 
-          if (choice.index === undefined)
-            throw new Error(
-              "No index in choice. Decoded choice: " + JSON.stringify(choice)
-            );
-
-          const correspondingNodeId =
-            // If we re-used a node we have to pull it from children array.
-            overrideExistingIfPossible && choice.index < currentNodeChildren.length
-              ? currentNodeChildren[choice.index].id
-              : newNodes[newNodes.length - responses + choice.index].id;
-
-          // The ChatGPT API will start by returning a
-          // choice with only a role delta and no content.
-          if (choice.delta?.content) {
-            setNodes((newerNodes) => {
-              try {
-                return appendTextToFluxNodeAsGPT(newerNodes, {
-                  id: correspondingNodeId,
-                  text: choice.delta?.content ?? UNDEFINED_RESPONSE_STRING,
-                  streamId, // This will cause a throw if the streamId has changed.
-                });
-              } catch (e: any) {
-                // If the stream id does not match,
-                // it is stale and we should abort.
-                abortController.abort(e.message);
-
-                return newerNodes;
-              }
-            });
-          }
-
-          // We cannot return within the loop, and we do
-          // not want to execute the code below, so we break.
-          if (abortController.signal.aborted) break;
-
-          // If the choice has a finish reason, then it's the final
-          // choice and we can mark it as no longer animated right now.
-          if (choice.finish_reason !== null) {
-            // Reset the stream id.
-            setNodes((nodes) =>
-              setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined })
-            );
-
-            setEdges((edges) =>
-              modifyFluxEdge(edges, {
-                source: parentNode.id,
-                target: correspondingNodeId,
-                animated: false,
-              })
-            );
-          }
-        } catch (err) {
-          console.error(err);
-        }
-      }
+      for (let i = 0; i < responses; i++) {
+        const { text } = await generateText({
+          apiKey: key,
+          provider: provider as ApiKeyProvider,
+          model,
+          messages,
+          temperature: temp,
+        });
 
-      // If the stream wasn't aborted or was aborted due to a cancelation.
-      if (
-        !abortController.signal.aborted ||
-        abortController.signal.reason === STREAM_CANCELED_ERROR_MESSAGE
-      ) {
-        // Mark all the edges as no longer animated.
-        for (let i = 0; i < responses; i++) {
-          const correspondingNodeId =
-            overrideExistingIfPossible && i < currentNodeChildren.length
-              ? currentNodeChildren[i].id
-              : newNodes[newNodes.length - responses + i].id;
-
-          // Reset the stream id.
-          setNodes((nodes) =>
-            setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined })
-          );
+        const correspondingNodeId =
+          overrideExistingIfPossible && i < currentNodeChildren.length
+            ? currentNodeChildren[i].id
+            : newNodes[newNodes.length - responses + i].id;
 
-          setEdges((edges) =>
-            modifyFluxEdge(edges, {
-              source: parentNode.id,
-              target: correspondingNodeId,
-              animated: false,
-            })
-          );
-        }
+        setNodes((newerNodes) =>
+          appendTextToFluxNodeAsGPT(newerNodes, {
+            id: correspondingNodeId,
+            text: text ?? UNDEFINED_RESPONSE_STRING,
+            streamId,
+          })
+        );
+        setNodes((nodes) =>
+          setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined })
+        );
+        setEdges((edges) =>
+          modifyFluxEdge(edges, {
+            source: parentNode.id,
+            target: correspondingNodeId,
+            animated: false,
+          })
+        );
+      }
     })().catch((err) =>
       toast({
@@ -574,7 +507,7 @@ function App() {
"No choices in response. Decoded response: " + JSON.stringify(decoded)
);

const choice: CreateCompletionResponseChoicesInner = decoded.choices[0];
const choice = decoded.choices[0];

setNodes((newerNodes) => {
try {
@@ -870,6 +803,12 @@ function App() {
   const [anthropicKey] = useLocalStorage<string>(FLUX_ANTHROPIC_API_KEY);
   const [googleKey] = useLocalStorage<string>(FLUX_GOOGLE_API_KEY);
 
+  const allApiKeys: Record<ApiKeyProvider, string | null> = {
+    openai: apiKey,
+    anthropic: anthropicKey,
+    google: googleKey,
+  };
+
   const apiBase = import.meta.env.VITE_OPENAI_API_BASE;
 
   type AvailableModels = Partial<Record<ApiKeyProvider, string[]>> | null;
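One consequence of the rewrite above: the old code streamed all n choices over a single request, while the new loop awaits each generateText call before starting the next, so responses now arrive sequentially. If latency mattered, the same helper could be fanned out concurrently. A hedged sketch, not part of this commit, reusing key, provider, model, messages, temp, and responses from the surrounding scope:

// Hypothetical alternative to the sequential loop: issue all N completions at once.
const results = await Promise.all(
  Array.from({ length: responses }, () =>
    generateText({
      apiKey: key,
      provider: provider as ApiKeyProvider,
      model,
      messages,
      temperature: temp,
    })
  )
);
// Each results[i].text would then be written to the i-th node, as in the loop above.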
38 changes: 38 additions & 0 deletions src/utils/generateText.ts
@@ -0,0 +1,38 @@
+import { generateText, CoreMessage } from "ai";
+import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
+import { createAnthropic } from "@ai-sdk/anthropic";
+import {
+  createGoogleGenerativeAI,
+  GoogleGenerativeAIProviderSettings,
+} from "@ai-sdk/google";
+import { ApiKeyProvider } from "./apikey";
+
+interface GenerateAI {
+  apiKey: string;
+  provider: ApiKeyProvider;
+  model: string;
+  temperature?: number;
+  system?: string;
+  prompt?: string;
+  messages?: Array<CoreMessage>;
+}
+
+const openAiBaseUrl = import.meta.env.VITE_OPENAI_API_BASE;
+const geminiBaseUrl = import.meta.env.VITE_GEMINI_API_BASE;
+
+const SDK: Record<ApiKeyProvider, any> = {
+  openai: ({ apiKey }: OpenAIProviderSettings) =>
+    createOpenAI({ apiKey, baseURL: openAiBaseUrl }),
+  anthropic: createAnthropic,
+  google: ({ apiKey }: GoogleGenerativeAIProviderSettings) =>
+    createGoogleGenerativeAI({ apiKey, baseURL: geminiBaseUrl }),
+};
+
+async function generate({ apiKey, provider, model, ...rest }: GenerateAI) {
+  return await generateText({
+    model: SDK[provider]({ apiKey })(model),
+    ...rest,
+  });
+}
+
+export default generate;
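For reference, a minimal usage sketch of the new helper (App.tsx imports the default export under the name generateText; the key and model id below are placeholders):

// Hypothetical usage of the default-exported generate helper.
import generate from "./generateText";

const { text } = await generate({
  apiKey: "sk-...", // placeholder key
  provider: "openai",
  model: "gpt-4o-mini", // placeholder model id
  temperature: 0.7,
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(text);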
7 changes: 6 additions & 1 deletion src/utils/prompt.ts
@@ -1,8 +1,13 @@
 import { FluxNodeData, FluxNodeType, Settings } from "./types";
-import { ChatCompletionRequestMessage } from "openai";
+// import { ChatCompletionRequestMessage } from "openai";
 import { MAX_AUTOLABEL_CHARS } from "./constants";
 import { Node } from "reactflow";
 
+interface ChatCompletionRequestMessage {
+  role: "system" | "user" | "assistant";
+  content: string;
+}
+
 export function messagesFromLineage(
   lineage: Node<FluxNodeData>[],
   settings: Settings
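The inlined interface replaces the identically named type that the removed openai dependency used to provide, so messagesFromLineage keeps returning the same shape. For example:

// Example of the message shape, typed by the local interface above.
const example: ChatCompletionRequestMessage[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Summarize this thread." },
];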
4 changes: 2 additions & 2 deletions src/utils/types.ts
@@ -1,6 +1,6 @@
 import { Node, Edge } from "reactflow";
 
-import { ChatCompletionResponseMessage } from "openai";
+// import { ChatCompletionResponseMessage } from "openai";
 
 export type FluxNodeData = {
   label: string;
@@ -32,7 +32,7 @@ export enum ReactFlowNodeTypes {
 // The stream response is weird and has a delta instead of message field.
 export interface CreateChatCompletionStreamResponseChoicesInner {
   index?: number;
-  delta?: ChatCompletionResponseMessage;
+  // delta?: ChatCompletionResponseMessage;
   finish_reason?: string;
 }
