Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming support #26

Merged
merged 6 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ AZURE_SEARCH_URL_COLUMN=
AZURE_OPENAI_RESOURCE=
AZURE_OPENAI_MODEL=
AZURE_OPENAI_KEY=
AZURE_OPENAI_MODEL_NAME=gpt-35-turbo
AZURE_OPENAI_TEMPERATURE=0
AZURE_OPENAI_TOP_P=1.0
AZURE_OPENAI_MAX_TOKENS=1000
AZURE_OPENAI_STOP_SEQUENCE=
AZURE_OPENAI_SYSTEM_MESSAGE=You are an AI assistant that helps people find information.
AZURE_OPENAI_PREVIEW_API_VERSION=2023-03-31-preview
AZURE_OPENAI_PREVIEW_API_VERSION=2023-06-01-preview
AZURE_OPENAI_STREAM=True
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Feel free to fork this repository and make your own modifications to the UX or b
|AZURE_SEARCH_INDEX|||
|AZURE_SEARCH_KEYv
|AZURE_SEARCH_USE_SEMANTIC_SEARCH|False||
|AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG|default||
|AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG|||
|AZURE_SEARCH_INDEX_IS_PRECHUNKED|False||
|AZURE_SEARCH_TOP_K|5||
|AZURE_SEARCH_ENABLE_IN_DOMAIN|False||
Expand All @@ -64,14 +64,16 @@ Feel free to fork this repository and make your own modifications to the UX or b
|AZURE_SEARCH_TITLE_COLUMN|||
|AZURE_SEARCH_URL_COLUMN|||
|AZURE_OPENAI_RESOURCE|||
|AZURE_OPENAI_MODEL|||
|AZURE_OPENAI_MODEL||The name of your model deployment|
|AZURE_OPENAI_MODEL_NAME|gpt-35-turbo|The name of the model|
|AZURE_OPENAI_KEY|||
|AZURE_OPENAI_TEMPERATURE|0||
|AZURE_OPENAI_TOP_P|1.0||
|AZURE_OPENAI_MAX_TOKENS|1000||
|AZURE_OPENAI_STOP_SEQUENCE|||
|AZURE_OPENAI_SYSTEM_MESSAGE|You are an AI assistant that helps people find information.||
|AZURE_OPENAI_API_VERSION|2023-03-31-preview||
|AZURE_OPENAI_PREVIEW_API_VERSION|2023-06-01-preview||
|AZURE_OPENAI_STREAM|True||


## Contributing
Expand Down
191 changes: 144 additions & 47 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
import logging
import requests
from flask import Flask, request, jsonify
import openai
from flask import Flask, Response, request, jsonify
from dotenv import load_dotenv

load_dotenv()
Expand Down Expand Up @@ -36,7 +38,20 @@ def static_file(path):
AZURE_OPENAI_STOP_SEQUENCE = os.environ.get("AZURE_OPENAI_STOP_SEQUENCE")
AZURE_OPENAI_SYSTEM_MESSAGE = os.environ.get("AZURE_OPENAI_SYSTEM_MESSAGE", "You are an AI assistant that helps people find information.")
AZURE_OPENAI_PREVIEW_API_VERSION = os.environ.get("AZURE_OPENAI_PREVIEW_API_VERSION", "2023-06-01-preview")
AZURE_OPENAI_STREAM = os.environ.get("AZURE_OPENAI_STREAM", "true")
AZURE_OPENAI_MODEL_NAME = os.environ.get("AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo") # Name of the model, e.g. 'gpt-35-turbo' or 'gpt-4'

SHOULD_STREAM = True if AZURE_OPENAI_STREAM.lower() == "true" else False

def is_chat_model():
if 'gpt-4' in AZURE_OPENAI_MODEL_NAME.lower():
return True
return False

def should_use_data():
if AZURE_SEARCH_SERVICE and AZURE_SEARCH_INDEX and AZURE_SEARCH_KEY:
return True
return False

def prepare_body_headers_with_data(request):
request_messages = request.json["messages"]
Expand All @@ -47,7 +62,7 @@ def prepare_body_headers_with_data(request):
"max_tokens": AZURE_OPENAI_MAX_TOKENS,
"top_p": AZURE_OPENAI_TOP_P,
"stop": AZURE_OPENAI_STOP_SEQUENCE.split("|") if AZURE_OPENAI_STOP_SEQUENCE else [],
"stream": False,
"stream": SHOULD_STREAM,
"dataSources": [
{
"type": "AzureCognitiveSearch",
Expand All @@ -71,79 +86,161 @@ def prepare_body_headers_with_data(request):
]
}

chatgpt_url = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{AZURE_OPENAI_MODEL}"
if is_chat_model():
chatgpt_url += "/chat/completions?api-version=2023-03-15-preview"
else:
chatgpt_url += "/completions?api-version=2023-03-15-preview"

headers = {
'Content-Type': 'application/json',
'api-key': AZURE_OPENAI_KEY,
'chatgpt_url': f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{AZURE_OPENAI_MODEL}/completions?api-version=2023-03-31-preview",
'chatgpt_url': chatgpt_url,
'chatgpt_key': AZURE_OPENAI_KEY,
"x-ms-useragent": "GitHubSampleWebApp/PublicAPI/1.0.0"
}

return body, headers

def prepare_body_headers_without_data(request):

def stream_with_data(body, headers, endpoint):
s = requests.Session()
response = {
"id": "",
"model": "",
"created": 0,
"object": "",
"choices": [{
"messages": []
}]
}
try:
with s.post(endpoint, json=body, headers=headers, stream=True) as r:
for line in r.iter_lines(chunk_size=10):
if line:
lineJson = json.loads(line.lstrip(b'data:').decode('utf-8'))
if 'error' in lineJson:
yield json.dumps(lineJson).replace("\n", "\\n") + "\n"
response["id"] = lineJson["id"]
response["model"] = lineJson["model"]
response["created"] = lineJson["created"]
response["object"] = lineJson["object"]

role = lineJson["choices"][0]["messages"][0]["delta"].get("role")
if role == "tool":
response["choices"][0]["messages"].append(lineJson["choices"][0]["messages"][0]["delta"])
elif role == "assistant":
response["choices"][0]["messages"].append({
"role": "assistant",
"content": ""
})
else:
deltaText = lineJson["choices"][0]["messages"][0]["delta"]["content"]
if deltaText != "[DONE]":
response["choices"][0]["messages"][1]["content"] += deltaText

yield json.dumps(response).replace("\n", "\\n") + "\n"
except Exception as e:
yield json.dumps({"error": str(e)}).replace("\n", "\\n") + "\n"


def conversation_with_data(request):
body, headers = prepare_body_headers_with_data(request)
endpoint = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"

if not SHOULD_STREAM:
r = requests.post(endpoint, headers=headers, json=body)
status_code = r.status_code
r = r.json()

return Response(json.dumps(r).replace("\n", "\\n"), status=status_code)
else:
if request.method == "POST":
return Response(stream_with_data(body, headers, endpoint), mimetype='text/event-stream')
else:
return Response(None, mimetype='text/event-stream')

def stream_without_data(response):
responseText = ""
for line in response:
deltaText = line["choices"][0]["delta"].get('content')
if deltaText and deltaText != "[DONE]":
responseText += deltaText

response_obj = {
"id": line["id"],
"model": line["model"],
"created": line["created"],
"object": line["object"],
"choices": [{
"messages": [{
"role": "assistant",
"content": responseText
}]
}]
}
yield json.dumps(response_obj).replace("\n", "\\n") + "\n"


def conversation_without_data(request):
openai.api_type = "azure"
openai.api_base = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
openai.api_key = AZURE_OPENAI_KEY

request_messages = request.json["messages"]
body_messages = [
messages = [
{
"role": "system",
"content": AZURE_OPENAI_SYSTEM_MESSAGE
}
]

for message in request_messages:
body_messages.append({
messages.append({
"role": message["role"] ,
"content": message["content"]
})

body = {
"messages": body_messages,
"temperature": float(AZURE_OPENAI_TEMPERATURE),
"top_p": float(AZURE_OPENAI_TOP_P),
"max_tokens": int(AZURE_OPENAI_MAX_TOKENS),
"stream": False
}

headers = {
'Content-Type': 'application/json',
'api-key': AZURE_OPENAI_KEY
}

if AZURE_OPENAI_STOP_SEQUENCE:
sequences = AZURE_OPENAI_STOP_SEQUENCE.split("|")
body["stop"] = sequences

return body, headers
response = openai.ChatCompletion.create(
engine=AZURE_OPENAI_MODEL,
messages = messages,
temperature=float(AZURE_OPENAI_TEMPERATURE),
max_tokens=int(AZURE_OPENAI_MAX_TOKENS),
top_p=float(AZURE_OPENAI_TOP_P),
stop=AZURE_OPENAI_STOP_SEQUENCE.split("|") if AZURE_OPENAI_STOP_SEQUENCE else None,
stream=SHOULD_STREAM
)

if not SHOULD_STREAM:
response_obj = {
"id": response,
"model": response.model,
"created": response.created,
"object": response.object,
"choices": [{
"messages": [{
"role": "assistant",
"content": response.choices[0].message.content
}]
}]
}

def should_use_data():
if AZURE_SEARCH_SERVICE and AZURE_SEARCH_INDEX and AZURE_SEARCH_KEY:
return True
return False
return jsonify(response_obj), 200
else:
if request.method == "POST":
return Response(stream_without_data(response), mimetype='text/event-stream')
else:
return Response(None, mimetype='text/event-stream')

@app.route("/conversation", methods=["POST"])
@app.route("/conversation", methods=["GET", "POST"])
def conversation():
try:
base_url = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com"
use_data = should_use_data()
if use_data:
body, headers = prepare_body_headers_with_data(request)
endpoint = f"{base_url}/openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"
return conversation_with_data(request)
else:
body, headers = prepare_body_headers_without_data(request)
endpoint = f"{base_url}/openai/deployments/{AZURE_OPENAI_MODEL}/chat/completions?api-version=2023-03-15-preview"

r = requests.post(endpoint, headers=headers, json=body)
status_code = r.status_code
r = r.json()

if not use_data and status_code == 200:
# convert to the same format as the data version
r["choices"][0]["messages"] = [{
"content": r["choices"][0]["message"]["content"],
"role": "assistant"
}]

return jsonify(r), status_code
return conversation_without_data(request)
except Exception as e:
logging.exception("Exception in /conversation")
return jsonify({"error": str(e)}), 500
Expand Down
13 changes: 2 additions & 11 deletions frontend/src/api/api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ChatResponse, ConversationRequest } from "./models";

export async function conversationApi(options: ConversationRequest, abortSignal: AbortSignal): Promise<ChatResponse> {
export async function conversationApi(options: ConversationRequest, abortSignal: AbortSignal): Promise<Response> {
const response = await fetch("/conversation", {
method: "POST",
headers: {
Expand All @@ -12,14 +12,5 @@ export async function conversationApi(options: ConversationRequest, abortSignal:
signal: abortSignal
});

const parsedResponse: ChatResponse = await response.json();

if (response.status > 299 || !response.ok) {
console.log("Error response from /conversation", parsedResponse)
const message = "An error occurred. Please try again. If the problem persists, please contact the site administrator.";
alert(message);
throw Error(message);
}

return parsedResponse;
return response;
}
3 changes: 1 addition & 2 deletions frontend/src/api/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export type ToolMessageContent = {
export type ChatMessage = {
role: string;
content: string;
end_turn: boolean | null;
end_turn?: boolean;
};

export enum ChatCompletionType {
Expand All @@ -32,7 +32,6 @@ export enum ChatCompletionType {
}

export type ChatResponseChoice = {
index: number;
messages: ChatMessage[];
}

Expand Down
1 change: 0 additions & 1 deletion frontend/src/pages/chat/Chat.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -247,5 +247,4 @@
order: 1;
align-self: stretch;
flex-grow: 0;
white-space: pre-wrap;
}
Loading