search related doc title before similarity search
wvivi2023 2023-11-06 08:57:58 +08:00
parent e6382cacb1
commit 526c4b52a8
14 changed files with 73 additions and 6 deletions

BIN .DS_Store (vendored, new file; binary not shown)

configs/prompt_config.py

@@ -33,7 +33,7 @@ PROMPT_TEMPLATES["llm_chat"] = {
 PROMPT_TEMPLATES["knowledge_base_chat"] = {
     "default":
         """
-        <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+        <指令>完全依据已知信息的内容,以一个电力专家的视角,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,不回答与问题无关的问题,答案请使用中文。 </指令>
         <已知信息>{{ context }}</已知信息>、
         <问题>{{ question }}</问题>
         """,

BIN server/.DS_Store (vendored, new file; binary not shown)

server/chat/knowledge_base_chat.py

@@ -39,6 +39,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     history = [History.from_data(h) for h in history]
+    #weiwei
+    print(f"server/chat/knowledge_base_chat function, history:{history}")

     async def knowledge_base_chat_iterator(query: str,
                                            top_k: int,
@@ -46,6 +48,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                            model_name: str = LLM_MODEL,
                                            prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
+        #weiwei
+        print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
             model_name=model_name,
@@ -55,12 +60,21 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
+        #weiwei
+        print(f"knowledge_base_chat_iterator,search docs context:{context}")

         prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+        #weiwei
+        print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
+        #weiwei
+        print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")

         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
@@ -69,6 +83,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                 callback.done),
         )
+        #weiwei
+        print(f"task call end")

         source_documents = []
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
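The code around these hunks streams the answer while the chain runs in the background: an AsyncIteratorCallbackHandler is attached to the model, chain.acall is scheduled as an asyncio task via the repo's wrap_done helper, and tokens are read from callback.aiter(). A condensed, standalone sketch of that pattern, assuming the chat model was created with streaming=True:

# Condensed sketch of the streaming wiring; error propagation via the repo's
# wrap_done helper is omitted for brevity.
import asyncio
from typing import AsyncIterable
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain

async def stream_answer(chain: LLMChain, inputs: dict) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    # Run the chain in the background; tokens land in the callback as generated.
    task = asyncio.create_task(chain.acall(inputs, callbacks=[callback]))
    async for token in callback.aiter():  # iterator closes when on_llm_end fires
        yield token
    await task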

BIN server/knowledge_base/.DS_Store (vendored, new file; binary not shown)

server/knowledge_base/__init__.py

@@ -1,3 +1,7 @@
 # from .kb_api import list_kbs, create_kb, delete_kb
 # from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
 # from .utils import KnowledgeFile, KBServiceFactory
+from server.knowledge_base.kb_doc_api import *
+from server.knowledge_base.kb_api import *
+from server.knowledge_base.utils import *

server/knowledge_base/kb_doc_api.py

@@ -29,8 +29,23 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return []
+    # query = "根据国网安徽信通公司安全准入实施要求," + query
+    pre_doc = kb.search_docs(query, 1)
+    print(f"len(pre_doc):{len(pre_doc)}")
+    if len(pre_doc) > 0:
+        print(f"search_docs, len(pre_doc):{len(pre_doc)}")
+        filpath = pre_doc[0][0].metadata['source']
+        file_name = os.path.basename(filpath)
+        file_name, file_extension = os.path.splitext(file_name)
+        query = "根据" + file_name + "" + query
+        print(f"search_docs, query:{query}")
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
+    # i = 1
+    # for x in docs:
+    #     print(f"相似文档 {i}: {x}")
+    #     i = i+1
     return data
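This hunk is what the commit title refers to: before the real top_k similarity search, a top-1 probe fetches the best-matching document, its file name (minus extension) is taken as a title, and the query is re-issued with the prefix 根据<title> ("according to <title>"), biasing retrieval toward chunks of that document. A standalone sketch of the idea, for any retriever with the repo's search_docs(query, top_k, score_threshold) shape returning (Document, score) pairs:

# Hedged sketch of the title-prefix pre-search; the bare string concatenation
# (no separator between title and query) mirrors the committed code.
import os

def title_aware_search(kb, query: str, top_k: int, score_threshold: float):
    pre_doc = kb.search_docs(query, 1)  # top-1 probe for the dominant document
    if len(pre_doc) > 0:
        source_path = pre_doc[0][0].metadata["source"]
        title, _ext = os.path.splitext(os.path.basename(source_path))
        query = "根据" + title + query  # "according to <title>, <query>"
    return kb.search_docs(query, top_k, score_threshold)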

Binary file not shown.

server/knowledge_base/kb_service/faiss_kb_service.py

@@ -60,8 +60,10 @@ class FaissKBService(KBService):
                   score_threshold: float = SCORE_THRESHOLD,
                   embeddings: Embeddings = None,
                   ) -> List[Document]:
+        print(f"do_search,top_k:{top_k},score_threshold:{score_threshold}")
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+            print(f"do_search,docs:{docs}")
         return docs

     def do_add_doc(self,
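The prints added to do_search expose what FAISS's similarity_search_with_score returns: (Document, score) tuples, where the score is a distance that score_threshold prunes. A self-contained sketch; FakeEmbeddings stands in for the project's real embedding model, so the scores are random placeholders and the threshold is deliberately loose:

# Self-contained sketch of the (Document, score) pairs printed in do_search.
# With a real embedding model, lower distance means a closer match and
# score_threshold drops anything above it.
from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

vs = FAISS.from_texts(
    ["准入手续包括申请、审核和备案。", "其他无关内容。"],
    FakeEmbeddings(size=128),
)
for doc, score in vs.similarity_search_with_score("准入手续", k=2, score_threshold=1000.0):
    print(doc.page_content, score)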

Binary file not shown.

Binary file not shown.

server/knowledge_base/utils.py

@@ -15,6 +15,7 @@ from configs import (
 import importlib
 from text_splitter import zh_title_enhance as func_zh_title_enhance
 import langchain.document_loaders
+from langchain.document_loaders.word_document import Docx2txtLoader
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
@@ -76,8 +77,9 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
                "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                           '.rtf', '.txt', '.xml',
-                                          '.docx', '.epub', '.odt',
+                                          '.epub', '.odt',
                                           '.ppt', '.pptx', '.tsv'],
+               "Docx2txtLoader": ['.docx'],
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
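The loader change removes '.docx' from UnstructuredFileLoader's extension list and maps it to Docx2txtLoader, which extracts plain text through the lighter docx2txt package. Minimal usage sketch; the file name is hypothetical:

# Minimal sketch of the new .docx path; "某文档.docx" is a hypothetical file.
from langchain.document_loaders import Docx2txtLoader

docs = Docx2txtLoader("某文档.docx").load()  # yields a single Document per file
print(docs[0].page_content[:200])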
@@ -281,6 +283,7 @@ class KnowledgeFile:
         self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
         self.text_splitter_name = TEXT_SPLITTER_NAME
+        print(f"KnowledgeFile: filepath:{self.filepath}")

     def file2docs(self, refresh: bool=False):
         if self.docs is None or refresh:
@@ -312,8 +315,13 @@ class KnowledgeFile:
                     doc.metadata["source"] = os.path.basename(self.filepath)
         else:
             docs = text_splitter.split_documents(docs)
-        print(f"文档切分示例:{docs[0]}")
+        #print(f"文档切分示例:{docs[0]}")
+        i = 0
+        for doc in docs:
+            print(f"**********切分段{i}{doc}")
+            i = i+1
         if zh_title_enhance:
             docs = func_zh_title_enhance(docs)
         self.splited_docs = docs
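Where the old code printed only the first split ("文档切分示例", a sample split document), the replacement enumerates every chunk, producing the numbered 切分段 (split segment) lines in the logs. The same loop against a bare splitter, with assumed chunk settings:

# Per-chunk debug sketch; chunk_size/chunk_overlap are assumed values, not the
# repo's configured ones.
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
docs = splitter.split_documents([Document(page_content="第一条 总则。" * 100)])
for i, doc in enumerate(docs):
    print(f"**********切分段{i}{doc}")  # same format as the committed print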

test.py (new file, 21 lines)

@@ -0,0 +1,21 @@
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base import KnowledgeFile
+
+if __name__ == '__main__':
+    from pprint import pprint
+
+    #kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file = KnowledgeFile(filename="国网安徽信通公司安全准入实施要求_修订.docx", knowledge_base_name="test")
+    # docs = kb_file.file2docs()
+    # pprint(docs[-1])
+
+    # docs = kb_file.file2text()
+    # pprint(docs[-1])
+
+    faissService = FaissKBService("test")
+    faissService.add_doc(KnowledgeFile("国网安徽信通公司安全准入实施要求_修订.docx", "test"))
+    # faissService.delete_doc(KnowledgeFile("README.md", "test"))
+    # faissService.do_drop_kb()
+    print(faissService.search_docs("准入手续的内容是什么?"))

text_splitter/chinese_recursive_text_splitter.py

@@ -98,7 +98,7 @@ if __name__ == "__main__":
     ]
     # text = """"""
     for inum, text in enumerate(ls):
-        print(inum)
+        print(f"**************分段:{inum}")
         chunks = text_splitter.split_text(text)
         for chunk in chunks:
-            print(chunk)
+            print(f"**************:chunk:{chunk}")