search related doc title before similarity search
parent e6382cacb1
commit 526c4b52a8
@@ -33,7 +33,7 @@ PROMPT_TEMPLATES["llm_chat"] = {
 PROMPT_TEMPLATES["knowledge_base_chat"] = {
     "default":
         """
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+<指令>完全依据已知信息的内容,以一个电力专家的视角,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,不回答与问题无关的问题,答案请使用中文。 </指令>
 <已知信息>{{ context }}</已知信息>、
 <问题>{{ question }}</问题>
 """,
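The {{ context }} and {{ question }} slots in the template above are Jinja2-style placeholders (the diff below renders them via History.to_msg_template). A minimal rendering sketch, assuming only the jinja2 package; the abbreviated PROMPT string and the render_prompt helper are illustrative, not part of this commit:

from jinja2 import Template

# Abbreviated stand-in for the "knowledge_base_chat" template above; the real
# template carries the full Chinese instruction block.
PROMPT = (
    "<已知信息>{{ context }}</已知信息>\n"
    "<问题>{{ question }}</问题>"
)

def render_prompt(context: str, question: str) -> str:
    # Jinja2 only substitutes the two slots; no other template logic is used.
    return Template(PROMPT).render(context=context, question=question)

print(render_prompt("准入流程见第 3 章。", "准入手续的内容是什么?"))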
Binary file not shown.
@@ -39,6 +39,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     history = [History.from_data(h) for h in history]
+    #weiwei
+    print(f"server/chat/knowledge_base_chat function, history:{history}")

     async def knowledge_base_chat_iterator(query: str,
                                            top_k: int,
@@ -46,6 +48,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                            model_name: str = LLM_MODEL,
                                            prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
+        #weiwei
+        print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")
+
         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
             model_name=model_name,
@@ -55,12 +60,21 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
+        #weiwei
+        print(f"knowledge_base_chat_iterator,search docs context:{context}")
+
         prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)

+        #weiwei
+        print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")
+
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

+        #weiwei
+        print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
+
         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
@@ -69,6 +83,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                      callback.done),
         )

+        #weiwei
+        print(f"task call end")
+
         source_documents = []
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
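The two hunks above sit inside LangChain's background-task streaming pattern: the LLMChain runs as an asyncio task while the response is drained token by token from an AsyncIteratorCallbackHandler. A self-contained sketch of that pattern; wrap_done mirrors the project helper of the same name, whose body this diff does not show, so it is reconstructed here as an assumption:

import asyncio
from typing import AsyncIterable, Awaitable

from langchain.callbacks import AsyncIteratorCallbackHandler

async def wrap_done(fn: Awaitable, event: asyncio.Event):
    # Run the chain to completion, then unblock the token consumer.
    try:
        await fn
    finally:
        event.set()

async def stream_chain(chain, inputs: dict) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    # Begin a task that runs in the background while tokens are drained below.
    task = asyncio.create_task(
        wrap_done(chain.acall(inputs, callbacks=[callback]), callback.done))
    async for token in callback.aiter():
        yield token  # each LLM token as soon as it arrives
    await task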
Binary file not shown.
@@ -1,3 +1,7 @@
 # from .kb_api import list_kbs, create_kb, delete_kb
 # from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
 # from .utils import KnowledgeFile, KBServiceFactory
+
+from server.knowledge_base.kb_doc_api import *
+from server.knowledge_base.kb_api import *
+from server.knowledge_base.utils import *
@@ -29,8 +29,23 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return []
+    # query = "根据国网安徽信通公司安全准入实施要求," + query
+    pre_doc = kb.search_docs(query, 1)
+    print(f"len(pre_doc):{len(pre_doc)}")
+    if len(pre_doc) > 0:
+        print(f"search_docs, len(pre_doc):{len(pre_doc)}")
+        filpath = pre_doc[0][0].metadata['source']
+        file_name = os.path.basename(filpath)
+        file_name, file_extension = os.path.splitext(file_name)
+        query = "根据" +file_name + ","+ query
+
+    print(f"search_docs, query:{query}")
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
+    # i = 1
+    # for x in docs:
+    #     print(f"相似文档 {i}: {x}")
+    #     i = i+1

     return data

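This hunk is the change the commit title describes: before the real similarity search, a top-1 pre-search finds the most related document, its file name (in effect the document title) is stripped of path and extension, and the query is rewritten to "根据<title>,<query>" so the top_k search is anchored to that document. The same rewrite as a standalone function, assuming a kb object exposing the search_docs(query, top_k) interface used above:

import os

def rewrite_query_with_title(kb, query: str) -> str:
    # Pre-search: the single most similar (Document, score) pair.
    pre_doc = kb.search_docs(query, 1)
    if len(pre_doc) > 0:
        # The doc "title" is the source file's base name minus its extension.
        file_path = pre_doc[0][0].metadata["source"]
        file_name, _ext = os.path.splitext(os.path.basename(file_path))
        # Anchor the real search to that document.
        query = "根据" + file_name + "," + query
    return query

With a knowledge base holding 国网安徽信通公司安全准入实施要求_修订.docx, the query "准入手续的内容是什么?" becomes "根据国网安徽信通公司安全准入实施要求_修订,准入手续的内容是什么?" before the top_k search, matching the commented-out line at the top of the hunk.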
Binary file not shown.
@@ -60,8 +60,10 @@ class FaissKBService(KBService):
                   score_threshold: float = SCORE_THRESHOLD,
                   embeddings: Embeddings = None,
                   ) -> List[Document]:
+        print(f"do_search,top_k:{top_k},score_threshold:{score_threshold}")
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
+        print(f"do_search,docs:{docs}")
         return docs

     def do_add_doc(self,
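vs here is a LangChain FAISS vector store, and similarity_search_with_score returns (Document, score) pairs where the score is an L2 distance (lower means more similar); score_threshold drops pairs above it. A toy sketch that runs without a real embedding model, assuming faiss-cpu is installed; FakeEmbeddings produces random vectors, so the loose threshold below is only there to let both texts through:

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

# Illustrative two-document store; real code uses the knowledge base's
# configured embedding model instead of FakeEmbeddings.
vs = FAISS.from_texts(
    ["准入手续包括申请、评估与审批。", "系统上线前须完成安全测试。"],
    FakeEmbeddings(size=128),
)
for doc, score in vs.similarity_search_with_score(
        "准入手续的内容是什么?", k=2, score_threshold=1e9):
    print(score, doc.page_content)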
Binary file not shown.
Binary file not shown.
@@ -15,6 +15,7 @@ from configs import (
 import importlib
 from text_splitter import zh_title_enhance as func_zh_title_enhance
 import langchain.document_loaders
+from langchain.document_loaders.word_document import Docx2txtLoader
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
@@ -76,8 +77,9 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
                "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                           '.rtf', '.txt', '.xml',
-                                          '.docx', '.epub', '.odt',
+                                          '.epub', '.odt',
                                           '.ppt', '.pptx', '.tsv'],
+               "Docx2txtLoader":['.docx'],
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

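The hunk above moves .docx out of UnstructuredFileLoader and routes it to the newly imported Docx2txtLoader, a lighter loader backed by the docx2txt package. A minimal sketch of what that loader returns; the file path is illustrative:

from langchain.document_loaders import Docx2txtLoader

# Docx2txtLoader extracts plain text and returns a single Document whose
# metadata["source"] is the file path.
loader = Docx2txtLoader("国网安徽信通公司安全准入实施要求_修订.docx")
docs = loader.load()
print(docs[0].metadata["source"], len(docs[0].page_content))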
@@ -281,6 +283,7 @@ class KnowledgeFile:
         self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
         self.text_splitter_name = TEXT_SPLITTER_NAME
+        print(f"KnowledgeFile: filepath:{self.filepath}")

     def file2docs(self, refresh: bool=False):
         if self.docs is None or refresh:
@@ -312,8 +315,13 @@ class KnowledgeFile:
                     doc.metadata["source"] = os.path.basename(self.filepath)
             else:
                 docs = text_splitter.split_documents(docs)

-        print(f"文档切分示例:{docs[0]}")
+        #print(f"文档切分示例:{docs[0]}")
+        i = 0
+        for doc in docs:
+            print(f"**********切分段{i}:{doc}")
+            i = i+1
+
         if zh_title_enhance:
             docs = func_zh_title_enhance(docs)
         self.splited_docs = docs
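The hunk above replaces the one-off 文档切分示例 sample with a loop that prints every split segment, which makes chunk boundaries easy to eyeball. The same split-and-inspect loop as a standalone sketch, with RecursiveCharacterTextSplitter standing in for the project's configured TEXT_SPLITTER_NAME and an illustrative chunk size:

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=0)
docs = text_splitter.split_documents([Document(
    page_content="安全准入是指对接入电网信息系统的设备、应用和人员进行审查。"
                 "准入手续包括申请、评估、整改与审批四个环节。")])
for i, doc in enumerate(docs):
    print(f"**********切分段{i}:{doc}")  # same style as the debug output above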
@@ -0,0 +1,21 @@
+
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base import KnowledgeFile
+
+if __name__ == '__main__':
+    from pprint import pprint
+
+    #kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file = KnowledgeFile(filename="国网安徽信通公司安全准入实施要求_修订.docx", knowledge_base_name="test")
+    # docs = kb_file.file2docs()
+    # pprint(docs[-1])
+    # docs = kb_file.file2text()
+    # pprint(docs[-1])
+
+    faissService = FaissKBService("test")
+    faissService.add_doc(KnowledgeFile("国网安徽信通公司安全准入实施要求_修订.docx", "test"))
+    # faissService.delete_doc(KnowledgeFile("README.md", "test"))
+    # faissService.do_drop_kb()
+    print(faissService.search_docs("准入手续的内容是什么?"))
+
+
@@ -98,7 +98,7 @@ if __name__ == "__main__":
     ]
     # text = """"""
     for inum, text in enumerate(ls):
-        print(inum)
+        print(f"**************分段:{inum}")
         chunks = text_splitter.split_text(text)
         for chunk in chunks:
-            print(chunk)
+            print(f"**************:chunk:{chunk}")