diff --git a/.DS_Store b/.DS_Store
index ace25c6..582265c 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/server/.DS_Store b/server/.DS_Store
index caa0468..5a32a3e 100644
Binary files a/server/.DS_Store and b/server/.DS_Store differ
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 4402185..a548328 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -10,7 +10,7 @@ from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
 from server.utils import get_prompt_template
-
+from langchain.prompts import PromptTemplate

 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                history: List[History] = Body([],
@@ -41,11 +41,34 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             callbacks=[callback],
         )

-        prompt_template = get_prompt_template("llm_chat", prompt_name)
-        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
-        chat_prompt = ChatPromptTemplate.from_messages(
-            [i.to_msg_template() for i in history] + [input_msg])
-        chain = LLMChain(prompt=chat_prompt, llm=model)
+        # augment_prompt_template = get_prompt_template("data_augment", "default")
+        # input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
+        # chat_prompt = ChatPromptTemplate.from_messages(
+        #     [i.to_msg_template() for i in history] + [input_msg])
+        # chain = LLMChain(prompt=chat_prompt, llm=model)
+        # print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")
+
+        # # Begin a task that runs in the background.
+        # task = asyncio.create_task(wrap_done(
+        #     chain.acall({"question": query}),
+        #     callback.done),
+        # )
+
+        prompt = ChatPromptTemplate.from_template("你是一个非常聪明的语义转换专家,总能找到同一个语义的不同表达方式。请简洁生成一个与三单引号里的原句子语气语调完全一致、语义最相似的新句子。注意不是回答三单引号里的原句子,新句子直接使用简体中文给出,也不要重复三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案。'''{input}'''")
+
+        # prompt_template = get_prompt_template("llm_chat", prompt_name)
+        # input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+        # chat_prompt = ChatPromptTemplate.from_messages(
+        #     [i.to_msg_template() for i in history] + [input_msg])
+        chain = LLMChain(prompt=prompt, llm=model)
+
+        name = "John"
+        age = 30
+        text = "你是一个非常聪明的语义转换专家,总能找到同一个语义的不同表达方式。请简洁生成一个与三单引号里的原句子语气语调完全一致、语义最相似的新句子。注意不是回答三单引号里的原句子,新句子直接使用简体中文给出,也不要重复三单引号里的原句子。如果无法给出满足条件的新句子,直接给出三单引号里的原句子,而不是给出三单引号里的原句子的答案。\n ''' "
+        text = text + query
+        text = text + " ''' "
+        result = ", ".join(["Name: ", name, ", Age: ", str(age)])

         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
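Note on this hunk: `/chat` no longer answers the user's message. The normal `llm_chat` prompt is commented out and a hard-coded paraphrase prompt is wired into the `LLMChain` instead, so the model rewrites the query into a semantically equivalent sentence. A minimal sketch of that chain in isolation, with `FakeListLLM` standing in for `get_ChatOpenAI` so it runs offline (the abridged prompt text and the canned response are illustrative, not from the patch):

```python
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts.chat import ChatPromptTemplate

# Abridged version of the hard-coded paraphrase prompt; the full Chinese
# instruction text lives in the hunk above.
paraphrase_prompt = ChatPromptTemplate.from_template(
    "请简洁生成一个与三单引号里的原句子语气语调完全一致、语义最相似的新句子:'''{input}'''"
)

# FakeListLLM is a stand-in for the streaming ChatOpenAI model; the canned
# response is illustrative, not real model output.
llm = FakeListLLM(responses=["又羞又恼,顿时大发雷霆"])
chain = LLMChain(prompt=paraphrase_prompt, llm=llm)

print(chain.run({"input": "恼羞成怒"}))  # prints a paraphrase, not an answer
```

Note that the template must use a single-braced `{input}`: with `{{input}}` the braces are escaped, the prompt has no input variable, and the user's query never reaches the model.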
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 9d2cec7..f7e9b3a 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -14,12 +14,12 @@ import json
 import os
 from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
-
+from langchain.prompts import PromptTemplate

 async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                               knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                               top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-2之间,SCORE越小,相关度越高,取到2相当于不筛选,建议设置在0.5左右", ge=0, le=2),
                               history: List[History] = Body([],
                                                             description="历史对话",
                                                             examples=[[
@@ -39,17 +39,30 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     history = [History.from_data(h) for h in history]
-    #weiwei
     print(f"server/chat/knowledge_base_chat function, history:{history}")
-
     async def knowledge_base_chat_iterator(query: str,
                                            top_k: int,
                                            history: Optional[List[History]],
                                            model_name: str = LLM_MODEL,
                                            prompt_name: str = prompt_name,
                                            ) -> AsyncIterable[str]:
-        #weiwei
         print(f"knowledge_base_chat_iterator,query:{query},model_name:{model_name},prompt_name:{prompt_name}")
+
+        model1 = get_ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            callbacks=[],
+        )
+
+#        augment_prompt_template = get_prompt_template("data_augment", "default")
+#        input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
+#        chat_prompt1 = ChatPromptTemplate.from_messages(
+#            [i.to_msg_template() for i in history] + [input_msg1])
+#        chain1 = LLMChain(prompt=chat_prompt1, llm=model1)
+#        print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
+#        result = chain1._call({ "question": query})
+#        print(f"chain1._call, result:{result}")

         callback = AsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
@@ -58,23 +71,36 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             max_tokens=max_tokens,
             callbacks=[callback],
         )
-        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
+
+        #augment_prompt_template = get_prompt_template("data_augment", "default")
+        #input_msg = History(role="user", content=augment_prompt_template).to_msg_template(False)
+        #chat_prompt = ChatPromptTemplate.from_messages(
+        #    [i.to_msg_template() for i in history] + [input_msg])
+        #chain = LLMChain(prompt=chat_prompt, llm=model)
+        #print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt}")
+        ##chain = LLMChain(prompt=PromptTemplate.from_template(augment_prompt_template), llm=model)
+        ##print(f"knowledge_base_chat_iterator,prompt_template:{augment_prompt_template}")
+        #task = asyncio.create_task(wrap_done(
+        #    chain.acall({ "question": query}),
+        #    callback.done),
+        #)
+        #prompt_template = "请找出和{question}最相似的一句话"
+        #llm_chain = LLMChain(prompt=PromptTemplate.from_template(prompt_template), llm=model)
+        #result = llm_chain(query)
+        #print(f"请找出和question 最相似的一句话:{result}")
+
+        docs = search_docs(query, knowledge_base_name, top_k, score_threshold, model1)
         context = "\n".join([doc.page_content for doc in docs])
-        #weiwei -
-        print(f"knowledge_base_chat_iterator,search docs context:{context}")
+
+        #print(f"knowledge_base_chat_iterator,search docs context:{context}")

         prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
-
-        #weiwei
         print(f"knowledge_base_chat_iterator,input_msg:{input_msg}")
-
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
-        #weiwei -
-        print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
-
+        #print(f"knowledge_base_chat_iterator,chat_prompt:{chat_prompt}")
         chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
@@ -82,10 +108,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             chain.acall({"context": context, "question": query}),
             callback.done),
         )
-
-        #weiwei
-        print(f"task call end")
+        print("task call end")
         source_documents = []
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
@@ -93,6 +117,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             url = f"/knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)
+
+        print(f"knowledge_base_chat_iterator, stream:{stream}")
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
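Note on this hunk: the substantive change is that `knowledge_base_chat_iterator` now builds a second, non-streaming model (`model1`, created with `callbacks=[]`) and passes it to `search_docs`, while the streaming model keeps feeding the `AsyncIteratorCallbackHandler`. For reference, a minimal sketch of that streaming pattern; the `wrap_done` here mirrors the shape of `server.utils.wrap_done`, and the fake chain (with artificial pacing) is a stand-in for `chain.acall`:

```python
import asyncio

from langchain.callbacks import AsyncIteratorCallbackHandler


async def wrap_done(fn, event: asyncio.Event):
    # Same shape as server.utils.wrap_done: await the coroutine, then signal completion.
    try:
        await fn
    finally:
        event.set()


async def main():
    callback = AsyncIteratorCallbackHandler()

    async def fake_chain():
        # Stand-in for chain.acall({...}): emit tokens through the callback,
        # pausing briefly to mimic a model streaming.
        for tok in ["你", "好", "!"]:
            await callback.on_llm_new_token(tok)
            await asyncio.sleep(0.05)

    # Begin a task that runs in the background, exactly as the handler does.
    task = asyncio.create_task(wrap_done(fake_chain(), callback.done))
    async for token in callback.aiter():
        print(token, end="", flush=True)  # the SSE loop yields these to the client
    await task


asyncio.run(main())
```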
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 1eef729..2298f0a 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -15,7 +15,13 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_file_repository import get_file_detail
 from typing import List, Dict
 from langchain.docstore.document import Document
-
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
+from server.chat.utils import History
+from server.utils import BaseResponse, get_prompt_template
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.callbacks import AsyncIteratorCallbackHandler

 class DocumentWithScore(Document):
     score: float = None
@@ -25,32 +31,82 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
                 knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                 top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                 score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                model: ChatOpenAI = Body(..., description="大语言模型"),
                 ) -> List[DocumentWithScore]:
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return []
-    # query = "根据国网安徽信通公司安全准入实施要求," + query
+
+    #history = {}
+    #augment_prompt_template = get_prompt_template("data_augment", "default")
+    #input_msg1 = History(role="user", content=augment_prompt_template).to_msg_template(False)
+    #chat_prompt1 = ChatPromptTemplate.from_messages(
+    #    [i.to_msg_template() for i in history] + [input_msg1])
+    #chain1 = LLMChain(prompt=chat_prompt1, llm=model)
+    #print(f"knowledge_base_chat_iterator,prompt_template:{chat_prompt1}")
+    #result = chain1._call({ "question": query})
+    #query1 = result["text"]
+    #
+    #print(f"相似的问法:{query1}")
+    #docs = kb.search_docs(query, top_k, score_threshold)
+    #print(f"{query}的相似文档块有{docs}")
+    #data = []
+    #if query1 != query:
+    #    docs1 = kb.search_docs(query1, top_k, score_threshold)
+    #    print(f"{query1}的相似文档块有{docs1}")
+    #    rerank_docs = rerank(docs1, docs, top_k)
+    #    print(f"精排后的相似文档块有{rerank_docs}")
+    #    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in rerank_docs]
+    #else:
+    #    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
+
+    #print(f"chain1._call, result:{result},similarity text:{query1}")
+
     pre_doc = kb.search_docs(query, 1, None)
     print(f"len(pre_doc):{len(pre_doc)}")
     if len(pre_doc) > 0:
-        print(f"search_docs, len(pre_doc):{len(pre_doc)}")
+        print(f"search_docs, pre_doc:{pre_doc}")
         filpath = pre_doc[0][0].metadata['source']
         file_name = os.path.basename(filpath)
         file_name, file_extension = os.path.splitext(file_name)
         query = "根据" + file_name + "," + query
-        print(f"search_docs, query:{query}")
+        print(f"search_docs, query:{query}")
     docs = kb.search_docs(query, top_k, score_threshold)
+    print(f"search_docs, docs:{docs}")
     if len(pre_doc) > 0:
         if docs is not None:
             docs.append(pre_doc[0])
         else:
             docs = pre_doc[0]
+
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
-
     return data

+def rerank(query_docs: List[Document] = Body(..., description="源query查询相似文档", examples=[]),
+           augment_docs: List[Document] = Body(..., description="增强query查询相似文档", examples=[]),
+           top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+           ) -> List[Document]:
+    all_documents = query_docs + augment_docs
+    # unique_documents_dict = {doc[0].page_content: doc for doc in all_documents}
+
+    # Deduplicate by page_content, keeping the hit with the smaller (better) score
+    unique_documents_dict = {}
+    for doc in all_documents:
+        if doc[0].page_content not in unique_documents_dict or doc[1] < unique_documents_dict[doc[0].page_content][1]:
+            unique_documents_dict[doc[0].page_content] = doc
+
+    # The deduplicated document list
+    unique_documents = list(unique_documents_dict.values())
+
+    # Sort ascending by score (smaller distance means higher relevance) and keep top_k
+    sorted_documents = sorted(unique_documents, key=lambda doc: doc[1], reverse=False)
+    min_documents = sorted_documents[:top_k]
+    # Print the result
+    for doc in min_documents:
+        print(f"{doc[0].page_content},{doc[1]}")
+
+    return min_documents

 def list_files(
         knowledge_base_name: str
@@ -158,6 +214,8 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
     failed_files = {}
     file_names = list(docs.keys())

+    print(f"upload_docs, file_names:{file_names}")
+
     # First save the uploaded files to disk
     for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
         filename = result["data"]["file_name"]
@@ -167,7 +225,9 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
         if filename not in file_names:
             file_names.append(filename)

+
     # Vectorize the saved files
+    print(f"upload_docs, to_vector_store:{to_vector_store}")
     if to_vector_store:
         result = update_docs(
             knowledge_base_name=knowledge_base_name,
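Note on this hunk: `search_docs` now prefixes the query with the file name of its single best hit (so the query becomes "根据<file_name>,<query>"), and the new `rerank()` merges the hit lists for the original and augmented queries, dedupes by `page_content` keeping the smaller distance, and returns the `top_k` best. A standalone check of that merge logic, with the FastAPI `Body` defaults stripped so it runs directly (the sample documents and scores are illustrative):

```python
from langchain.docstore.document import Document


def rerank(query_docs, augment_docs, top_k):
    # Same logic as the rerank() in the hunk above, operating on (Document, score) pairs.
    all_documents = query_docs + augment_docs
    unique = {}
    for doc, score in all_documents:
        # Keep each page_content once, with its smallest (best) distance score.
        if doc.page_content not in unique or score < unique[doc.page_content][1]:
            unique[doc.page_content] = (doc, score)
    # Ascending by distance, truncated to top_k.
    return sorted(unique.values(), key=lambda pair: pair[1])[:top_k]


hits_original = [(Document(page_content="条款一"), 0.42), (Document(page_content="条款二"), 0.88)]
hits_augmented = [(Document(page_content="条款一"), 0.35), (Document(page_content="条款三"), 0.70)]

for doc, score in rerank(hits_original, hits_augmented, top_k=2):
    print(score, doc.page_content)  # 0.35 条款一, then 0.70 条款三
```

One caveat: the declared return type `List[Document]` is inaccurate, since both inputs and outputs are `(Document, score)` tuples, as the commented-out caller (`DocumentWithScore(**x[0].dict(), score=x[1])`) expects.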
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 874f4bd..f25a706 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -334,7 +334,7 @@ class KnowledgeFile:
         with open(outputfile, 'w') as file:
             for doc in docs:
                 print(f"**********切分段{i}:{doc}")
-                file.write(f"分段{i}")
+                file.write(f"\n**********切分段{i}")
                 file.write(doc.page_content)
                 i = i+1

diff --git a/text_splitter/chinese_recursive_text_splitter.py b/text_splitter/chinese_recursive_text_splitter.py
index fca36b2..480d6ca 100644
--- a/text_splitter/chinese_recursive_text_splitter.py
+++ b/text_splitter/chinese_recursive_text_splitter.py
@@ -54,12 +54,12 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
         # Get appropriate separator to use
         separator = separators[-1]
         new_separators = [SPLIT_SEPARATOE]
-        #text = re.sub(r'(\n+[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on chapter/section numbers
         text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(?!\.|[a-zA-Z1-9]))', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on "1.2"-style section numbers
-        text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on table labels like 表 A.4.a
+        text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\.[A-Za-z0-9]+)+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on table labels like 表 A.4.a
         text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on "第…条" article headings
         text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on "第…章" chapter headings
         text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on Chinese ordinal headings 一、二、…
+        text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # then split on "1.2" with surrounding whitespace
         text = text.rstrip()  # strip any extra trailing \n at the end of the text
         for i, _s in enumerate(separators):
             _separator = _s if self._is_separator_regex else re.escape(_s)
@@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
                 if not new_separators:
                     final_chunks.append(s)
                 else:
-                    text = re.sub(r'(\n+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*(\.\s*[a-zA-Z1-9]+\s*)+)', r"\n\n\n\n\n\n\n\n\n\n\1", s)  # then split on "1.2.3"-style numbers
+                    text = re.sub(r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", s)  # then split on "1.2.3"-style numbers
                     other_info = self._split_text(s, new_separators)
                     final_chunks.extend(other_info)
                 if _good_splits:
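Note on these splitter edits: they all use the same trick. `re.sub` injects a run of ten newlines in front of each recognized heading (a "1.2" section number, a "表 A.4.a" table label, "第…条" / "第…章", or a Chinese ordinal like "一、"), and that newline run then acts as a hard split boundary for the recursive splitter. A small demo of the "1.2" rule with the `\s+` guards this patch adds (the sample text is illustrative):

```python
import re

# The "1.2"-style rule from the patch: leading whitespace, digits/letters,
# a dot, digits/letters, then required trailing whitespace (the \s+ guards
# avoid matching decimals embedded in running text).
SECTION = r'(\s+[a-zA-Z1-9]+\s*\.\s*[a-zA-Z1-9]+\s+)'
HARD_BOUNDARY = "\n" * 10  # the ten-newline run the splitter injects

text = "总则\n1.1 适用范围的说明。\n1.2 术语和定义的说明。"
marked = re.sub(SECTION, HARD_BOUNDARY + r"\1", text)

for chunk in marked.split(HARD_BOUNDARY):
    if chunk.strip():
        print(repr(chunk))  # 总则 / 1.1 … / 1.2 … as separate chunks
```

One quirk worth knowing: the character class `[a-zA-Z1-9]` omits `0`, so section numbers containing a zero (for example "1.10") do not match these rules.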
diff --git a/webui_pages/.DS_Store b/webui_pages/.DS_Store
new file mode 100644
index 0000000..2ff7808
Binary files /dev/null and b/webui_pages/.DS_Store differ
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index ff0f30f..2144b90 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -173,7 +173,7 @@ def dialogue_page(api: ApiRequest):
                 kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

                 ## BGE models' scores can exceed 1
-                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
+                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)

         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
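Note on this hunk (and the matching `le=2` change in `knowledge_base_chat.py`): the slider cap moves from 1.0 to 2.0 for the reason the comment gives. BGE-style embeddings are length-normalized, and the Euclidean distance between two unit vectors ranges over [0, 2] (since ||a - b|| = sqrt(2 - 2 cos θ)), so a distance score can legitimately exceed 1, and "no filtering" needs a cap of 2 rather than 1. A quick numeric check, assuming the knowledge-base score is a distance of this kind (the vectors are illustrative):

```python
import numpy as np

# Two L2-normalized (unit-length) embeddings, chosen as the worst case:
# diametrically opposite directions.
a = np.array([1.0, 0.0])
b = np.array([-1.0, 0.0])

cos = float(a @ b)                   # cosine similarity, in [-1, 1]
dist = float(np.linalg.norm(a - b))  # Euclidean distance, in [0, 2] for unit vectors

print(cos, dist)                     # -1.0 2.0: the distance score exceeds 1
assert abs(dist - np.sqrt(2 - 2 * cos)) < 1e-9
```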