# my_chainlit_example2.py
from io import BytesIO
import chainlit as cl
from dotenv import load_dotenv
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import PyPDF2
# Load environment variables via python-dotenv before any client is created.
load_dotenv()

# Splitter used to cut the uploaded document's text into overlapping chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=100,
    chunk_size=1000,
)

# System instructions; `{summaries}` is the placeholder for the context passed
# to the model, and the answer is required to carry a "SOURCES" part.
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
Example of your response should be:
```
The answer is foo
SOURCES: xyz
```
Begin!
----------------
{summaries}"""

# Chat prompt: the system instructions followed by the user's question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)

# Extra keyword arguments handed to the chain factory in `on_chat_start`.
chain_type_kwargs = dict(prompt=prompt)
def _extract_pdf_text(pdf_bytes: bytes) -> str:
    """Return the concatenated text of every page of a PDF given as raw bytes."""
    reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
    # `or ""` guards against a page yielding None/empty so `join` only ever
    # sees strings; joining also avoids repeated `+=` string concatenation.
    return "".join(page.extract_text() or "" for page in reader.pages)


@cl.on_chat_start
async def on_chat_start():
    """Greet the user, ingest one uploaded PDF and prepare the session's QA chain.

    Steps: ask for a PDF upload, extract and chunk its text, index the chunks
    in a Chroma vector store, build a ``RetrievalQAWithSourcesChain`` on top of
    it, and store the chain together with the chunk texts and their source
    metadata in the user session (keys ``"chain"``, ``"texts"``,
    ``"metadatas"``).

    If no text can be extracted from the PDF, the user is told so and nothing
    is stored in the session.
    """
    await cl.Message(content="Welcome to LangChain World!").send()

    # Wait for a PDF upload. `send()` may come back with nothing (presumably
    # when the prompt times out), so keep asking until a file is present.
    files = None
    while not files:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Read the PDF bytes.
    # NOTE(review): newer Chainlit releases appear to expose the upload as a
    # file on disk (`file.path`) instead of in-memory `file.content` -- confirm
    # against the installed version. Both shapes are accepted here.
    pdf_bytes = getattr(file, "content", None)
    if pdf_bytes is None:
        with open(file.path, "rb") as fh:
            pdf_bytes = fh.read()

    # Parsing is synchronous work; run it through `cl.make_async` (as is done
    # for Chroma below) so the event loop is not blocked meanwhile.
    pdf_text = await cl.make_async(_extract_pdf_text)(pdf_bytes)

    # Split the PDF text into chunks.
    texts = text_splitter.split_text(pdf_text)
    if not texts:
        # Nothing to index (e.g. a PDF without a text layer): report it instead
        # of handing an empty corpus to the vector store.
        msg.content = (
            f"No extractable text was found in `{file.name}`. "
            "Please start a new chat and upload a text-based PDF."
        )
        await msg.update()
        return

    # Give every chunk a source label ("<chunk index>-pl").
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Build the Chroma vector store.
    embeddings = OpenAIEmbeddings()
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    # Build the chain: a retrieval QA chain that also reports its sources.
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        ChatOpenAI(temperature=0, streaming=True),
        chain_type="stuff",
        chain_type_kwargs=chain_type_kwargs,
        retriever=docsearch.as_retriever(),
    )

    # Persist everything the message handler needs *before* announcing that
    # the app is ready, so a quick first question cannot find the session
    # without a chain.
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)
    cl.user_session.set("chain", chain)

    # Tell the user the upload has been processed.
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the session's QA chain and attach cited sources.

    Args:
        message: The incoming chat message; its ``content`` is the question
            passed to the chain.

    The chain's ``"sources"`` output is split on commas; every cited name that
    matches a stored chunk's ``"source"`` label is attached to the reply as a
    ``cl.Text`` element holding that chunk's text. If the callback handler
    already streamed the final answer, the elements are attached to the
    streamed message; otherwise a new message with the answer is sent.
    """
    chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain
    if chain is None:
        # No chain in the session yet (no document has been indexed): answer
        # with a hint instead of failing on `None.acall`.
        await cl.Message(
            content="Please upload a PDF file before asking questions."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): presumably makes the handler stream from the first token
    # instead of waiting for the prefix tokens above -- confirm.
    cb.answer_reached = True

    # Run the chain on the question.
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    sources = res["sources"].strip()
    source_elements = []

    # Map each stored source label to its chunk text once, instead of a linear
    # `list.index` search per cited source. `setdefault` keeps the first chunk
    # for a label, matching what `list.index` would have returned.
    metadatas = cl.user_session.get("metadatas") or []
    texts = cl.user_session.get("texts") or []
    text_by_source = {}
    for meta, text in zip(metadatas, texts):
        text_by_source.setdefault(meta["source"], text)

    if sources:
        found_sources = []
        # Attach the cited sources to the message.
        for source in sources.split(","):
            source_name = source.strip().replace(".", "")
            # Skip names that match no stored chunk, and names already
            # attached (citing one chunk twice would duplicate the element).
            if source_name not in text_by_source or source_name in found_sources:
                continue
            found_sources.append(source_name)
            # Text element referenced from the message.
            source_elements.append(
                cl.Text(content=text_by_source[source_name], name=source_name)
            )

        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    if cb.has_streamed_final_answer:
        # The answer text is already on screen; only add the source elements.
        cb.final_stream.elements = source_elements
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()