search related doc title before similarity search

This commit is contained in:
parent e6382cacb1
commit 526c4b52a8
@@ -33,7 +33,7 @@ PROMPT_TEMPLATES["llm_chat"] = {
 PROMPT_TEMPLATES["knowledge_base_chat"] = {
     "default":
         """
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+<指令>完全依据已知信息的内容,以一个电力专家的视角,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,不回答与问题无关的问题,答案请使用中文。 </指令>
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
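The template change narrows the system instruction: the model must now answer strictly from the supplied context, from the perspective of a power-sector expert (电力专家), and refuse off-topic questions. Below is a minimal sketch of how a template with jinja2-style `{{ context }}` / `{{ question }}` slots gets rendered; in the actual code path the template is wrapped by `History(role="user", content=prompt_template).to_msg_template(False)`, so the `HumanMessagePromptTemplate` here is only a stand-in:

```python
# Minimal sketch, not the repo's exact wiring. Requires langchain and jinja2.
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

TEMPLATE = (
    "<指令>完全依据已知信息的内容,以一个电力专家的视角,简洁和专业的来回答问题。</指令>\n"
    "<已知信息>{{ context }}</已知信息>\n"
    "<问题>{{ question }}</问题>"
)

# The {{ ... }} placeholders are jinja2 syntax, so template_format="jinja2"
# is required; the default f-string format would reject them.
input_msg = HumanMessagePromptTemplate.from_template(TEMPLATE, template_format="jinja2")
chat_prompt = ChatPromptTemplate.from_messages([input_msg])

print(chat_prompt.format(context="检索到的文档片段", question="准入手续的内容是什么?"))
```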
Binary file not shown.
@@ -39,6 +39,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     history = [History.from_data(h) for h in history]
+    #weiwei
+    print(f"server/chat/knowledge_base_chat function, history:{history}")

     async def knowledge_base_chat_iterator(query: str,
                                            top_k: int,
@@ -46,6 +48,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                            model_name: str = LLM_MODEL,
                                            prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
+        #weiwei
+        print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")
+
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
             model_name=model_name,
@@ -55,12 +60,21 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
+        #weiwei
+        print(f"knowledge_base_chat_iterator,search docs context:{context}")
+
         prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)

+        #weiwei
+        print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")
+
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

+        #weiwei
+        print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
+
         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
@@ -69,6 +83,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                 callback.done),
         )

+        #weiwei
+        print(f"task call end")
+
         source_documents = []
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
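The instrumented iterator follows langchain's token-streaming pattern: the chain runs as a background task while `AsyncIteratorCallbackHandler` yields tokens as they arrive. A stripped-down sketch of that pattern, assuming `model` is a streaming-capable chat model that reports tokens to the handler (as `get_ChatOpenAI` wires up here):

```python
# Minimal sketch of the streaming pattern behind knowledge_base_chat_iterator.
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain

async def stream_answer(chat_prompt, model, context: str, question: str):
    callback = AsyncIteratorCallbackHandler()
    chain = LLMChain(prompt=chat_prompt, llm=model)

    # Run the chain in the background; on_llm_end sets callback.done,
    # which terminates the aiter() loop below.
    task = asyncio.create_task(
        chain.acall({"context": context, "question": question}, callbacks=[callback])
    )
    async for token in callback.aiter():
        yield token  # forward each token to the caller as it arrives
    await task
```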
Binary file not shown.
@@ -1,3 +1,7 @@
 # from .kb_api import list_kbs, create_kb, delete_kb
 # from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
 # from .utils import KnowledgeFile, KBServiceFactory
+
+from server.knowledge_base.kb_doc_api import *
+from server.knowledge_base.kb_api import *
+from server.knowledge_base.utils import *
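With the commented-out selective imports replaced by wildcard re-exports, names defined in the submodules resolve from the package root, which the new smoke-test script later in this commit relies on. Assuming the submodules define no restrictive `__all__`:

```python
# These package-root imports now resolve; KnowledgeFile and KBServiceFactory
# come from server.knowledge_base.utils via the wildcard re-export.
from server.knowledge_base import KnowledgeFile, KBServiceFactory
```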
@@ -29,8 +29,23 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return []
+    # query = "根据国网安徽信通公司安全准入实施要求," + query
+    pre_doc = kb.search_docs(query, 1)
+    print(f"len(pre_doc):{len(pre_doc)}")
+    if len(pre_doc) > 0:
+        print(f"search_docs, len(pre_doc):{len(pre_doc)}")
+        filpath = pre_doc[0][0].metadata['source']
+        file_name = os.path.basename(filpath)
+        file_name, file_extension = os.path.splitext(file_name)
+        query = "根据" + file_name + "," + query
+
+    print(f"search_docs, query:{query}")
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
+    # i = 1
+    # for x in docs:
+    #     print(f"相似文档 {i}: {x}")
+    #     i = i+1

     return data
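This hunk is the change the commit title describes: a top-1 probe search finds the most related document, and its filename stem is prepended to the query as 根据{title},("according to {title}, …") before the real top-k search runs. A self-contained sketch of the idea; `SearchFn` and `title_augmented_query` are illustrative names standing in for `kb.search_docs`:

```python
import os
from typing import Callable, List, Tuple

# Illustrative stand-in for kb.search_docs: returns (doc, score) pairs where
# doc.metadata["source"] holds the originating file path.
SearchFn = Callable[[str, int], List[Tuple["Document", float]]]

def title_augmented_query(query: str, search: SearchFn) -> str:
    """Prefix the query with the best-matching document's title, if any."""
    pre_doc = search(query, 1)  # cheap top-1 probe for the most related file
    if len(pre_doc) > 0:
        filepath = pre_doc[0][0].metadata["source"]
        title, _ext = os.path.splitext(os.path.basename(filepath))
        query = "根据" + title + "," + query  # "according to <title>, <question>"
    return query
```

The payoff is that the second search is biased toward chunks from the probe's document; the risk is that a wrong top-1 hit drags the whole query toward the wrong file.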
Binary file not shown.
@@ -60,8 +60,10 @@ class FaissKBService(KBService):
                   score_threshold: float = SCORE_THRESHOLD,
                   embeddings: Embeddings = None,
                   ) -> List[Document]:
+        print(f"do_search,top_k:{top_k},score_threshold:{score_threshold}")
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+        print(f"do_search,docs:{docs}")
         return docs

     def do_add_doc(self,
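The two added prints bracket the underlying FAISS call. `similarity_search_with_score` returns `(Document, score)` pairs; with the default L2 index, lower scores mean closer matches, so `score_threshold` acts as an upper bound on kept hits (exact filtering semantics depend on the langchain version). A hedged sketch of the call shape:

```python
# Sketch only: `vs` is a loaded langchain FAISS vector store, as yielded by
# self.load_vector_store().acquire() above.
docs_with_scores = vs.similarity_search_with_score(
    "准入手续的内容是什么?",
    k=3,                  # top_k
    score_threshold=1.0,  # assumed L2 distance: keep hits with score <= 1.0
)
for doc, score in docs_with_scores:
    print(score, doc.metadata.get("source"))
```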
Binary file not shown.
Binary file not shown.
@@ -15,6 +15,7 @@ from configs import (
 import importlib
 from text_splitter import zh_title_enhance as func_zh_title_enhance
 import langchain.document_loaders
+from langchain.document_loaders.word_document import Docx2txtLoader
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
@@ -76,8 +77,9 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
                "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                           '.rtf', '.txt', '.xml',
-                                          '.docx', '.epub', '.odt',
+                                          '.epub', '.odt',
                                           '.ppt', '.pptx', '.tsv'],
+               "Docx2txtLoader": ['.docx'],
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

@@ -281,6 +283,7 @@ class KnowledgeFile:
         self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
         self.text_splitter_name = TEXT_SPLITTER_NAME
+        print(f"KnowledgeFile: filepath:{self.filepath}")

     def file2docs(self, refresh: bool=False):
         if self.docs is None or refresh:
@@ -312,8 +315,13 @@ class KnowledgeFile:
                     doc.metadata["source"] = os.path.basename(self.filepath)
         else:
             docs = text_splitter.split_documents(docs)

-        print(f"文档切分示例:{docs[0]}")
+        #print(f"文档切分示例:{docs[0]}")
+        i = 0
+        for doc in docs:
+            print(f"**********切分段{i}:{doc}")
+            i = i+1
+
         if zh_title_enhance:
             docs = func_zh_title_enhance(docs)
         self.splited_docs = docs
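The LOADER_DICT change routes `.docx` to a dedicated `Docx2txtLoader` entry instead of the catch-all `UnstructuredFileLoader`, so Word files take the lighter docx2txt path. A minimal sketch of the extension-to-loader lookup that `get_LoaderClass` (defined elsewhere in utils.py) performs; `get_loader_class` below is an illustrative reimplementation, not the repo's code:

```python
# Sketch of extension -> loader-name resolution over LOADER_DICT.
LOADER_DICT = {
    "UnstructuredFileLoader": ['.eml', '.msg', '.rst', '.rtf', '.txt', '.xml',
                               '.epub', '.odt', '.ppt', '.pptx', '.tsv'],
    "Docx2txtLoader": ['.docx'],
}

def get_loader_class(ext: str) -> str:
    for loader_name, exts in LOADER_DICT.items():
        if ext.lower() in exts:
            return loader_name
    raise ValueError(f"unsupported extension: {ext}")

assert get_loader_class(".docx") == "Docx2txtLoader"  # no longer Unstructured
```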
@@ -0,0 +1,21 @@
+
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base import KnowledgeFile
+
+if __name__ == '__main__':
+    from pprint import pprint
+
+    #kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file = KnowledgeFile(filename="国网安徽信通公司安全准入实施要求_修订.docx", knowledge_base_name="test")
+    # docs = kb_file.file2docs()
+    # pprint(docs[-1])
+    # docs = kb_file.file2text()
+    # pprint(docs[-1])
+
+    faissService = FaissKBService("test")
+    faissService.add_doc(KnowledgeFile("国网安徽信通公司安全准入实施要求_修订.docx", "test"))
+    # faissService.delete_doc(KnowledgeFile("README.md", "test"))
+    # faissService.do_drop_kb()
+    print(faissService.search_docs("准入手续的内容是什么?"))
+
+
@@ -98,7 +98,7 @@ if __name__ == "__main__":
     ]
     # text = """"""
     for inum, text in enumerate(ls):
-        print(inum)
+        print(f"**************分段:{inum}")
         chunks = text_splitter.split_text(text)
         for chunk in chunks:
-            print(chunk)
+            print(f"**************:chunk:{chunk}")