diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..65b9efa Binary files /dev/null and b/.DS_Store differ diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index a52b1f7..a84ca71 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -33,7 +33,7 @@ PROMPT_TEMPLATES["llm_chat"] = { PROMPT_TEMPLATES["knowledge_base_chat"] = { "default": """ - <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 + <指令>完全依据已知信息的内容,以一个电力专家的视角,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,不回答与问题无关的问题,答案请使用中文。 <已知信息>{{ context }}、 <问题>{{ question }} """, diff --git a/server/.DS_Store b/server/.DS_Store new file mode 100644 index 0000000..38c5157 Binary files /dev/null and b/server/.DS_Store differ diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 19ca871..9d2cec7 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -39,6 +39,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") history = [History.from_data(h) for h in history] + #weiwei + print(f"server/chat/knowledge_base_chat function, history:{history}") async def knowledge_base_chat_iterator(query: str, top_k: int, @@ -46,6 +48,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODEL, prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + #weiwei + print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}") + callback = AsyncIteratorCallbackHandler() model = get_ChatOpenAI( model_name=model_name, @@ -55,12 +60,21 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", ) docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) + #weiwei + print(f"knowledge_base_chat_iterator,search docs context:{context}") prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) + + #weiwei + print(f"knowledge_base_chat_iterator,input_msg:{input_msg}") + chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) + #weiwei + print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}") + chain = LLMChain(prompt=chat_prompt, llm=model) # Begin a task that runs in the background. @@ -69,6 +83,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", callback.done), ) + #weiwei + print(f"task call end") + source_documents = [] for inum, doc in enumerate(docs): filename = os.path.split(doc.metadata["source"])[-1] diff --git a/server/knowledge_base/.DS_Store b/server/knowledge_base/.DS_Store new file mode 100644 index 0000000..4030beb Binary files /dev/null and b/server/knowledge_base/.DS_Store differ diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py index 19de504..727debd 100644 --- a/server/knowledge_base/__init__.py +++ b/server/knowledge_base/__init__.py @@ -1,3 +1,7 @@ # from .kb_api import list_kbs, create_kb, delete_kb # from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store # from .utils import KnowledgeFile, KBServiceFactory + +from server.knowledge_base.kb_doc_api import * +from server.knowledge_base.kb_api import * +from server.knowledge_base.utils import * \ No newline at end of file diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 3d01f9e..831b139 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -29,8 +29,23 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: return [] + # query = "根据国网安徽信通公司安全准入实施要求," + query + pre_doc = kb.search_docs(query, 1) + print(f"len(pre_doc):{len(pre_doc)}") + if len(pre_doc) > 0: + print(f"search_docs, len(pre_doc):{len(pre_doc)}") + filpath = pre_doc[0][0].metadata['source'] + file_name = os.path.basename(filpath) + file_name, file_extension = os.path.splitext(file_name) + query = "根据" +file_name + ","+ query + + print(f"search_docs, query:{query}") docs = kb.search_docs(query, top_k, score_threshold) data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] + # i = 1 + # for x in docs: + # print(f"相似文档 {i}: {x}") + # i = i+1 return data diff --git a/server/knowledge_base/kb_service/.DS_Store b/server/knowledge_base/kb_service/.DS_Store new file mode 100644 index 0000000..f5068e3 Binary files /dev/null and b/server/knowledge_base/kb_service/.DS_Store differ diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index a72fcf7..28de35c 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -60,8 +60,10 @@ class FaissKBService(KBService): score_threshold: float = SCORE_THRESHOLD, embeddings: Embeddings = None, ) -> List[Document]: + print(f"do_search,top_k:{top_k},score_threshold:{score_threshold}") with self.load_vector_store().acquire() as vs: docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) + print(f"do_search,docs:{docs}") return docs def do_add_doc(self, diff --git a/server/knowledge_base/kb_service/knowledge_base/.DS_Store b/server/knowledge_base/kb_service/knowledge_base/.DS_Store new file mode 100644 index 0000000..32904a3 Binary files /dev/null and b/server/knowledge_base/kb_service/knowledge_base/.DS_Store differ diff --git a/server/knowledge_base/kb_service/knowledge_base/test/.DS_Store b/server/knowledge_base/kb_service/knowledge_base/test/.DS_Store new file mode 100644 index 0000000..e0859a5 Binary files /dev/null and b/server/knowledge_base/kb_service/knowledge_base/test/.DS_Store differ diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index c73d021..045ff9c 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -15,6 +15,7 @@ from configs import ( import importlib from text_splitter import zh_title_enhance as func_zh_title_enhance import langchain.document_loaders +from langchain.document_loaders.word_document import Docx2txtLoader from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter from pathlib import Path @@ -76,8 +77,9 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], "UnstructuredFileLoader": ['.eml', '.msg', '.rst', '.rtf', '.txt', '.xml', - '.docx', '.epub', '.odt', + '.epub', '.odt', '.ppt', '.pptx', '.tsv'], + "Docx2txtLoader":['.docx'], } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] @@ -281,6 +283,7 @@ class KnowledgeFile: self.splited_docs = None self.document_loader_name = get_LoaderClass(self.ext) self.text_splitter_name = TEXT_SPLITTER_NAME + print(f"KnowledgeFile: filepath:{self.filepath}") def file2docs(self, refresh: bool=False): if self.docs is None or refresh: @@ -312,8 +315,13 @@ class KnowledgeFile: doc.metadata["source"] = os.path.basename(self.filepath) else: docs = text_splitter.split_documents(docs) - - print(f"文档切分示例:{docs[0]}") + + #print(f"文档切分示例:{docs[0]}") + i = 0 + for doc in docs: + print(f"**********切分段{i}:{doc}") + i = i+1 + if zh_title_enhance: docs = func_zh_title_enhance(docs) self.splited_docs = docs diff --git a/test.py b/test.py new file mode 100644 index 0000000..1299786 --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ + +from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService +from server.knowledge_base import KnowledgeFile + +if __name__ == '__main__': + from pprint import pprint + + #kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples") + # kb_file = KnowledgeFile(filename="国网安徽信通公司安全准入实施要求_修订.docx", knowledge_base_name="test") + # docs = kb_file.file2docs() + # pprint(docs[-1]) + # docs = kb_file.file2text() + # pprint(docs[-1]) + + faissService = FaissKBService("test") + faissService.add_doc(KnowledgeFile("国网安徽信通公司安全准入实施要求_修订.docx", "test")) + # faissService.delete_doc(KnowledgeFile("README.md", "test")) + # faissService.do_drop_kb() + print(faissService.search_docs("准入手续的内容是什么?")) + + diff --git a/text_splitter/chinese_recursive_text_splitter.py b/text_splitter/chinese_recursive_text_splitter.py index 70b4b29..d5ee666 100644 --- a/text_splitter/chinese_recursive_text_splitter.py +++ b/text_splitter/chinese_recursive_text_splitter.py @@ -98,7 +98,7 @@ if __name__ == "__main__": ] # text = """""" for inum, text in enumerate(ls): - print(inum) + print(f"**************分段:{inum}") chunks = text_splitter.split_text(text) for chunk in chunks: - print(chunk) + print(f"**************:chunk:{chunk}")