-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Build Processor Pipeline for db build and use agent to generate
- Loading branch information
Showing
38 changed files
with
4,934 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
from typing import Any, List, Dict | ||
import re | ||
|
||
from .rag_agent import BaseAgent | ||
from lagent.rag.schema import Node, CommunityReport, Community, Chunk | ||
from lagent.rag.nlp import SentenceTransformerEmbedder, SimpleTokenizer | ||
from lagent.llms import DeepseekAPI, BaseAPILLM | ||
from lagent.rag.prompts import KNOWLEDGE_PROMPT | ||
from lagent.rag.processors import (DocParser, ChunkSplitter, EntityExtractor, DescriptionSummarizer, | ||
CommunitiesDetector, CommunityReportsExtractor, BuildDatabase, SaveGraph) | ||
|
||
|
||
class GraphRagAgent(BaseAgent): | ||
def __init__(self, | ||
llm: BaseAPILLM = dict(type=DeepseekAPI), | ||
embedder=dict(type=SentenceTransformerEmbedder), | ||
tokenizer=dict(type=SimpleTokenizer), | ||
processors_config: | ||
List = [dict(type=DocParser), dict(type=ChunkSplitter), dict(type=EntityExtractor), | ||
dict(type=DescriptionSummarizer), dict(type=CommunitiesDetector), | ||
dict(type=CommunityReportsExtractor), dict(type=BuildDatabase), dict(type=SaveGraph)], | ||
**kwargs): | ||
super().__init__(llm=llm, embedder=embedder, tokenizer=tokenizer, processors_config=processors_config, **kwargs) | ||
|
||
def forward(self, query: str, **kwargs) -> Any: | ||
""" | ||
Processes the input query and returns a response generated by the language model. | ||
This method performs the following steps: | ||
1. Initializes external memory. | ||
2. Performs similarity searches to identify top relevant entities. | ||
3. Builds community and chunk contexts based on the retrieved entities. | ||
4. Prepares the final prompt and interacts with the language model to generate a response. | ||
Args: | ||
query (str): The input query string to be processed. | ||
**kwargs (Any): Additional keyword arguments, such as prompts, token limits, and weighting factors. | ||
Returns: | ||
Any: The response generated by the language model based on the processed query and contexts. | ||
""" | ||
|
||
memory = self.external_memory | ||
prompt = kwargs.get('prompts', KNOWLEDGE_PROMPT) | ||
tokenizer = self.tokenizer | ||
|
||
max_ref_token = kwargs.get('max_ref_token', 4096) | ||
w_community = kwargs.get('w_community', 0.4) | ||
w_text = kwargs.get('w_text', 0.6) | ||
max_ref_token = max_ref_token - tokenizer.get_token_num(prompt) | ||
|
||
dict_chunks = memory.layers['chunk_layer'].get_nodes() | ||
chunks: List[Chunk] = [] | ||
for chunk in dict_chunks: | ||
chunks.append(Chunk.dict_to_chunk(chunk)) | ||
|
||
entities: List[Node] = [] | ||
|
||
assert 'summarized_entity_layer' in memory.layers | ||
dict_entities = memory.layers['summarized_entity_layer'].get_nodes() | ||
for entity in dict_entities: | ||
entities.append(Node.dict_to_node(entity)) | ||
|
||
top_k = kwargs.get('top_k', 8) | ||
entities_db = memory.layers_db['summarized_entity_layer'] | ||
|
||
selected_entities = [] | ||
selected_entities = entities_db.similarity_search_with_score(query, k=top_k) | ||
selected_entities = [(en[0].metadata['id'], en[1]) for en in selected_entities] | ||
|
||
# build community_reports context | ||
dict_commu_rep = memory.layers['community_report_layer'].get_nodes() | ||
community_reports: List[CommunityReport] = [] | ||
for commu in dict_commu_rep: | ||
community_reports.append(CommunityReport( | ||
community_id=commu['community_id'], | ||
level=commu['level'], | ||
report=commu['report'], | ||
structured_report=commu['structured_report'] | ||
)) | ||
|
||
dict_commu = memory.layers['community_layer'].get_nodes() | ||
communities: List[Community] = [] | ||
for commu in dict_commu: | ||
communities.append(Community( | ||
community_id=commu['id'], | ||
level=commu['level'], | ||
nodes_id=commu['nodes_id'] | ||
)) | ||
|
||
community_context = self.build_community_contexts(entities_with_score=selected_entities, | ||
community_reports=community_reports, | ||
max_tokens=int(max_ref_token * w_community), | ||
communities=communities) | ||
|
||
chunk_content = self.build_chunk_contexts(entities_with_score=selected_entities, | ||
entities=entities, | ||
chunks=chunks, | ||
max_tokens=int(max_ref_token * w_text)) | ||
|
||
final_search_contents = f'{community_context}\n{chunk_content}' | ||
|
||
result = self.prepare_prompt(query=query, knowledge=final_search_contents, prompt=prompt) | ||
|
||
messages = [{"role": "user", "content": result}] | ||
|
||
response = self.llm.chat(messages) | ||
|
||
return response | ||
|
||
def build_community_contexts(self, entities_with_score: List[tuple[str, float]], | ||
community_reports: List[CommunityReport], max_tokens: int, | ||
communities: List[Community]) -> str: | ||
""" | ||
Constructs the community context based on selected entities, community reports, and communities. | ||
This context is built by aggregating relevant community reports and ensuring | ||
that the total number of tokens does not exceed the specified maximum. | ||
Args: | ||
entities_with_score (List[tuple[str, float]]): A list of tuples containing entity IDs and their | ||
corresponding scores. | ||
community_reports (List[CommunityReport]): A list of CommunityReport instances. | ||
max_tokens (int): The maximum number of tokens allowed for the community context. | ||
communities (List[Community]): A list of Community instances. | ||
Returns: | ||
str: The constructed community context as a string. | ||
""" | ||
|
||
selected_entities = {} | ||
for entity in entities_with_score: | ||
selected_entities[entity[0]] = entity | ||
selected_commu = {} | ||
for community in communities: | ||
for node_id in community.nodes_id: | ||
if node_id in selected_entities: | ||
selected_commu[community.community_id] = selected_commu.get(community.community_id, 0) + \ | ||
selected_entities[node_id][1] | ||
selected_commu = dict(sorted(selected_commu.items(), key=lambda item: item[1], reverse=True)) | ||
|
||
id_map_reports = {} | ||
for community_report in community_reports: | ||
id_map_reports[community_report.community_id] = community_report | ||
|
||
selected_reports = [id_map_reports[community_id] for community_id in selected_commu.keys()] | ||
|
||
# Trim reports to fit within the token limit | ||
result = [f'------reports------' + '\n'] | ||
tokenizer = self.tokenizer | ||
for selected_report in selected_reports: | ||
content = selected_report.report | ||
token_num = tokenizer.get_token_num(content) | ||
if token_num <= max_tokens: | ||
result.append(content) | ||
max_tokens = max_tokens - token_num | ||
else: | ||
tmp = '' | ||
sentences = re.split(r'\. |。', content) | ||
j = 0 | ||
while j < len(sentences): | ||
sentence = sentences[j].strip() | ||
num = tokenizer.get_token_num(sentence) | ||
if num == 0: | ||
continue | ||
if num <= max_tokens: | ||
tmp = f'{tmp}.{sentence}' | ||
max_tokens = max_tokens - num | ||
j += 1 | ||
else: | ||
break | ||
result.append(tmp) | ||
break | ||
|
||
return '\n'.join(result) | ||
|
||
def build_chunk_contexts(self, entities_with_score: List[tuple[str, float]], entities: List[Node], | ||
chunks: List[Chunk], max_tokens: int): | ||
""" | ||
Constructs the chunk context based on selected entities and chunks. | ||
This context aggregates relevant chunks associated with the selected entities | ||
while ensuring that the total number of tokens does not exceed the specified maximum. | ||
Args: | ||
entities_with_score (List[tuple[str, float]]): A list of tuples containing entity IDs and their corresponding scores. | ||
entities (List[Node]): A list of Node instances representing entities. | ||
chunks (List[Chunk]): A list of Chunk instances to be considered. | ||
max_tokens (int): The maximum number of tokens allowed for the chunk context. | ||
Returns: | ||
str: The constructed chunk context as a string. | ||
""" | ||
id_map_entities = {} | ||
for entity in entities: | ||
id_map_entities[entity.id] = entity | ||
|
||
id_map_chunks = {} | ||
for chunk in chunks: | ||
id_map_chunks[chunk.id] = chunk | ||
|
||
selected_chunks: Dict[str, Dict] = {} | ||
for entity_with_score in entities_with_score: | ||
entity = id_map_entities[entity_with_score[0]] | ||
for chunk_id in entity.source_id: | ||
if chunk_id not in selected_chunks: | ||
selected_chunks[chunk_id] = {} | ||
selected_chunks[chunk_id]['score'] = selected_chunks[chunk_id].get('score', 0) + entity_with_score[1] | ||
|
||
selected_chunks = dict( | ||
sorted(selected_chunks.items(), key=lambda item: item[1]["score"], reverse=True) | ||
) | ||
|
||
result = [f'------texts------' + '\n'] | ||
tokenizer = self.tokenizer | ||
for chunk_id, selected_chunk in selected_chunks.items(): | ||
content = id_map_chunks[chunk_id].content | ||
token_num = id_map_chunks[chunk_id].token_num | ||
if token_num <= max_tokens: | ||
result.append(content) | ||
max_tokens -= token_num | ||
else: | ||
tmp = '' | ||
sentences = re.split(r'\. |。', content) | ||
j = 0 | ||
while j < len(sentences): | ||
sentence = sentences[j].strip() | ||
num = tokenizer.get_token_num(sentence) | ||
if num == 0: | ||
continue | ||
if num <= max_tokens: | ||
tmp = f'{tmp}.{sentence}' | ||
max_tokens = max_tokens - num | ||
j += 1 | ||
else: | ||
break | ||
result.append(tmp) | ||
|
||
return '\n'.join(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import re | ||
from typing import Any, List | ||
|
||
from .rag_agent import BaseAgent | ||
from lagent.rag.prompts import KNOWLEDGE_PROMPT | ||
from lagent.rag.schema import Document, Chunk, DocumentDB | ||
from lagent.llms import DeepseekAPI, BaseAPILLM | ||
from lagent.rag.processors import ChunkSplitter, DocParser, BuildDatabase, SaveGraph | ||
from lagent.rag.nlp import SentenceTransformerEmbedder, SimpleTokenizer | ||
from lagent.rag.settings import DEFAULT_LLM_MAX_TOKEN | ||
|
||
|
||
class NaiveRAGAgent(BaseAgent): | ||
def __init__(self, | ||
llm: BaseAPILLM = dict(type=DeepseekAPI), | ||
embedder=dict(type=SentenceTransformerEmbedder), | ||
tokenizer=dict(type=SimpleTokenizer), | ||
processors_config: | ||
List = [dict(type=DocParser), dict(type=ChunkSplitter), | ||
dict(type=BuildDatabase), dict(type=SaveGraph)], | ||
**kwargs): | ||
super().__init__(llm=llm, embedder=embedder, tokenizer=tokenizer, processors_config=processors_config, **kwargs) | ||
|
||
def forward(self, query: str, **kwargs) -> Any: | ||
memory = self.external_memory | ||
tokenizer = self.tokenizer | ||
|
||
prompt = kwargs.get('prompts', KNOWLEDGE_PROMPT) | ||
|
||
max_ref_token = kwargs.get('max_ref_token', DEFAULT_LLM_MAX_TOKEN) | ||
max_ref_token = max_ref_token - tokenizer.get_token_num(prompt) | ||
|
||
dict_chunks = memory.layers['chunk_layer'].get_nodes() | ||
chunks: List[Chunk] = [] | ||
for chunk in dict_chunks: | ||
chunks.append(Chunk.dict_to_chunk(chunk)) | ||
|
||
top_k = kwargs.get('top_k', 3) | ||
chunks_db = memory.layers_db['chunk_layer'] | ||
results = chunks_db.similarity_search_with_score(query, k=top_k) | ||
search_contents = [result[0].content for result in results] | ||
|
||
# TODO:find better ways to trim the context | ||
text = '\n'.join(search_contents) | ||
final_search_contents = self.trim_context(text, max_ref_token) | ||
|
||
result = self.prepare_prompt(query=query, knowledge=final_search_contents, prompt=prompt) | ||
|
||
messages = [{"role": "user", "content": result}] | ||
response = self.llm.chat(messages) | ||
|
||
return response | ||
|
||
def trim_context(self, text, max_ref_token): | ||
tokenizer = self.tokenizer | ||
token_num = tokenizer.get_token_num(text) | ||
if token_num <= max_ref_token: | ||
return text | ||
paras = text.split('\n') | ||
available_num = max_ref_token | ||
result = [] | ||
for para in paras: | ||
para = para.strip() | ||
token_num = tokenizer.get_token_num(para) | ||
if token_num <= available_num: | ||
result.append(para) | ||
available_num -= token_num | ||
else: | ||
tmp = '' | ||
sentences = re.split(r'\. |。', para) | ||
j = 0 | ||
while j < len(sentences): | ||
sentence = sentences[j].strip() | ||
num = tokenizer.get_token_num(sentence) | ||
if num == 0: | ||
continue | ||
if num <= available_num: | ||
tmp = f'{tmp}.{sentence}' | ||
max_tokens = available_num - num | ||
j += 1 | ||
else: | ||
break | ||
result.append(tmp) | ||
return '\n'.join(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from typing import Any, List, Dict | ||
import yaml | ||
import inspect | ||
|
||
from .agent import Agent | ||
from ..schema import AgentMessage | ||
from lagent.rag.pipeline import BaseProcessor | ||
from lagent.utils.util import create_object | ||
from lagent.rag.pipeline import Pipeline | ||
from lagent.llms import BaseAPILLM, DeepseekAPI | ||
from lagent.rag.nlp import SentenceTransformerEmbedder, SimpleTokenizer | ||
from lagent.rag.utils import replace_variables_in_prompt | ||
|
||
|
||
class BaseAgent(Agent): | ||
def __init__(self, | ||
processors_config: List, | ||
llm: BaseAPILLM=dict(type=DeepseekAPI), | ||
embedder=dict(type=SentenceTransformerEmbedder), | ||
tokenizer=dict(type=SimpleTokenizer), | ||
**kwargs): | ||
super().__init__(memory={}, **kwargs) | ||
self.external_memory = None | ||
self.llm = create_object(llm) | ||
self.embedder = create_object(embedder) | ||
self.tokenizer = create_object(tokenizer) | ||
|
||
self.processors = self.init_processors(processors_config) | ||
|
||
def init_external_memory(self, data): | ||
|
||
processors = self.processors | ||
pipeline = Pipeline() | ||
for processor in processors: | ||
pipeline.add_processor(processor) | ||
|
||
self.external_memory = pipeline.run(data) | ||
|
||
return self.external_memory | ||
|
||
def forward(self, **kwargs) -> Any: | ||
raise NotImplemented | ||
|
||
def init_processors(self, processors_config: List): | ||
processors = [] | ||
for processor_config in processors_config: | ||
if isinstance(processor_config, dict): | ||
processor = create_object(processor_config) | ||
elif isinstance(processor_config, object): | ||
processor = processor_config | ||
else: | ||
raise ValueError | ||
processors.append(processor) | ||
|
||
return processors | ||
|
||
def prepare_prompt(self, knowledge: str, query: str, prompt: str): | ||
|
||
prompt_variables = { | ||
'External_Knowledge': knowledge, | ||
'Query': query | ||
} | ||
|
||
prompt = replace_variables_in_prompt(prompt=prompt, prompt_variables=prompt_variables) | ||
|
||
return prompt |
Oops, something went wrong.