From 24d1e28a07f71633217b69e92e91e75e17ec41c6 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Fri, 27 Oct 2023 11:52:44 +0800
Subject: [PATCH] Some detail optimizations (#1891)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: zR
---
 configs/model_config.py.example       | 13 ++++-
 configs/prompt_config.py.example      | 11 +++-
 document_loaders/FilteredCSVloader.py | 80 ++++++++++++++++++++++++++
 requirements.txt                      |  2 +-
 requirements_api.txt                  |  2 +-
 requirements_lite.txt                 |  2 +-
 server/agent/custom_template.py       | 20 ++-----
 server/chat/knowledge_base_chat.py    | 11 +++-
 server/chat/search_engine_chat.py     |  3 +
 server/knowledge_base/kb_doc_api.py   |  1 -
 server/knowledge_base/utils.py        | 82 ++++++++++++++-------------
 webui_pages/dialogue/dialogue.py      | 48 +++++++++-------
 12 files changed, 190 insertions(+), 85 deletions(-)
 create mode 100644 document_loaders/FilteredCSVloader.py

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index efbedf9..74e9023 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -231,4 +231,15 @@ VLLM_MODEL_DICT = {
     "agentlm-13b":"THUDM/agentlm-13b",
     "agentlm-70b":"THUDM/agentlm-70b",
-}
\ No newline at end of file
+}
+
+## Models you consider capable of Agent behavior can be added here; once added, the warning in the web UI will no longer appear
+SUPPORT_AGENT_MODEL = [
+    "Azure-OpenAI",
+    "OpenAI",
+    "Anthropic",
+    "Qwen",
+    "qwen-api",
+    "baichuan-api",
+    "agentlm"
+]
\ No newline at end of file
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index e2d7ceb..6080ace 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -47,6 +47,11 @@ PROMPT_TEMPLATES["knowledge_base_chat"] = {
         <已知信息>{{ context }}、
         <问题>{{ question }}
         """,
+    "Empty":  # Used when retrieval finds nothing and there is no known information; this "Empty" template may be edited but must not be deleted, or the program will break
+    """
+    <指令>请根据用户的问题,进行简洁明了的回答
+    <问题>{{ question }}
+    """,
 }
 PROMPT_TEMPLATES["search_engine_chat"] = {
     "default":
@@ -55,13 +60,17 @@ PROMPT_TEMPLATES["search_engine_chat"] = {
         <已知信息>{{ context }}、
         <问题>{{ question }}
         """,
-
     "search":
     """
     <指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。
    <已知信息>{{ context }}、
     <问题>{{ question }}
     """,
+    "Empty":  # Used when retrieval finds nothing and there is no known information; this "Empty" template may be edited but must not be deleted, or the program will break
+    """
+    <指令>请根据用户的问题,进行简洁明了的回答
+    <问题>{{ question }}
+    """,
 }
 PROMPT_TEMPLATES["agent_chat"] = {
     "default":
diff --git a/document_loaders/FilteredCSVloader.py b/document_loaders/FilteredCSVloader.py
new file mode 100644
index 0000000..0f8148d
--- /dev/null
+++ b/document_loaders/FilteredCSVloader.py
@@ -0,0 +1,80 @@
+## A CSV file loader that only reads the specified columns
+
+from langchain.document_loaders import CSVLoader
+import csv
+from io import TextIOWrapper
+from typing import Dict, List, Optional
+from langchain.docstore.document import Document
+from langchain.document_loaders.helpers import detect_file_encodings
+
+
+class FilteredCSVLoader(CSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            columns_to_read: List[str],
+            source_column: Optional[str] = None,
+            metadata_columns: List[str] = [],
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
+    ):
+        super().__init__(
+            file_path=file_path,
+            source_column=source_column,
+            metadata_columns=metadata_columns,
+            csv_args=csv_args,
+            encoding=encoding,
+            autodetect_encoding=autodetect_encoding,
+        )
+        self.columns_to_read = columns_to_read
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+
+        docs = []
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self.__read_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(
+                                self.file_path, newline="", encoding=encoding.encoding
+                        ) as csvfile:
+                            docs = self.__read_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            if self.columns_to_read[0] in row:
+                content = row[self.columns_to_read[0]]
+                # Extract the source if available
+                source = (
+                    row.get(self.source_column, None)
+                    if self.source_column is not None
+                    else self.file_path
+                )
+                metadata = {"source": source, "row": i}
+
+                for col in self.metadata_columns:
+                    if col in row:
+                        metadata[col] = row[col]
+
+                doc = Document(page_content=content, metadata=metadata)
+                docs.append(doc)
+            else:
+                raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.")
+
+        return docs
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 364cc34..419c650 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
diff --git a/requirements_api.txt b/requirements_api.txt
index 4ef692b..126e65f 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.31
 xformers>=0.0.22.post4
diff --git a/requirements_lite.txt b/requirements_lite.txt
index d453451..4ff659e 100644
--- a/requirements_lite.txt
+++ b/requirements_lite.txt
@@ -1,4 +1,4 @@
-langchain>=0.0.319
+langchain>=0.0.324
 fschat>=0.2.31
 openai
 # sentence_transformers
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
index a08ba8e..2744b87 100644
--- a/server/agent/custom_template.py
+++ b/server/agent/custom_template.py
@@ -3,29 +3,22 @@ from langchain.agents import Tool, AgentOutputParser
 from langchain.prompts import StringPromptTemplate
 from typing import List
 from langchain.schema import AgentAction, AgentFinish
+
+from configs import SUPPORT_AGENT_MODEL
 from server.agent import model_container
 
 
 class CustomPromptTemplate(StringPromptTemplate):
-    # The template to use
     template: str
-    # The list of tools available
     tools: List[Tool]
 
     def format(self, **kwargs) -> str:
-        # Get the intermediate steps (AgentAction, Observation tuples)
-        # Format them in a particular way
         intermediate_steps = kwargs.pop("intermediate_steps")
         thoughts = ""
         for action, observation in intermediate_steps:
             thoughts += action.log
             thoughts += f"\nObservation: {observation}\nThought: "
-        # Set the agent_scratchpad variable to that value
         kwargs["agent_scratchpad"] = thoughts
-        # Create a tools variable from the list of tools provided
         kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
-        # Create a list of tool names for the tools provided
         kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
-        # Return the formatted templatepr
-        # print( self.template.format(**kwargs), end="\n\n")
         return self.template.format(**kwargs)
@@ -36,9 +29,7 @@ class CustomOutputParser(AgentOutputParser):
         self.begin = True
 
     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
-        # Check if agent should finish
-        support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # Models that currently support Agent use
-        if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
+        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]
             min_index = len(llm_output)
@@ -54,8 +45,6 @@ class CustomOutputParser(AgentOutputParser):
                 return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
                 log=llm_output,
             )
-
-        # Parse out the action and action input
         parts = llm_output.split("Action:")
         if len(parts) < 2:
             return AgentFinish(
@@ -66,7 +55,7 @@ class CustomOutputParser(AgentOutputParser):
         action = parts[1].split("Action Input:")[0].strip()
         action_input = parts[1].split("Action Input:")[1].strip()
 
-        # The original regex-based check
+        # The original regex-based check: stricter, but with a lower success rate
         # regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
         # print("llm_output",llm_output)
         # match = re.search(regex, llm_output, re.DOTALL)
@@ -80,7 +69,6 @@ class CustomOutputParser(AgentOutputParser):
         #     action_input = match.group(2)
 
         # Return the action and action input
-
         try:
             ans = AgentAction(
                 tool=action,
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index b7fe85c..01744c1 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -55,8 +55,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
-
-        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
+        if len(docs) == 0:  ## If no relevant documents are found, use the Empty template
+            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+        else:
+            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
@@ -76,6 +78,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             url = f"/knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)
+
+        if len(source_documents) == 0:  # No relevant documents were found
+            source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""")
+
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
@@ -88,7 +94,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             yield json.dumps({"answer": answer,
                               "docs": source_documents},
                              ensure_ascii=False)
-
         await task
 
     return StreamingResponse(knowledge_base_chat_iterator(query=query,
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 3d2100e..fcda1fa 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -176,6 +176,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             for inum, doc in enumerate(docs)
         ]
 
+        if len(source_documents) == 0:  # No relevant results were found (unlikely)
+            source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""")
+
         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 3d01f9e..11b8f45 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -31,7 +31,6 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
         return []
     docs = kb.search_docs(query, top_k, score_threshold)
     data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
-
     return data
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index c73d021..3981e81 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,4 +1,7 @@
 import os
+import sys
+
+sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
 from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
@@ -25,7 +28,6 @@ import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 import chardet
 
-
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # Check for unexpected characters or path-traversal keywords
     if "../" in knowledge_base_id:
@@ -72,6 +74,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],
                "CSVLoader": [".csv"],
+               # "FilteredCSVLoader": [".csv"], # must be configured manually; not supported yet
                "RapidOCRPDFLoader": [".pdf"],
                "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
                "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
@@ -88,12 +91,12 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
     '''
 
     def __init__(
-        self,
-        file_path: Union[str, Path],
-        content_key: Optional[str] = None,
-        metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
-        text_content: bool = True,
-        json_lines: bool = False,
+            self,
+            file_path: Union[str, Path],
+            content_key: Optional[str] = None,
+            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+            text_content: bool = True,
+            json_lines: bool = False,
     ):
         """Initialize the JSONLoader.
 
@@ -150,7 +153,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
     Return a document loader according to loader_name and the file path or content.
     '''
     try:
-        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
             document_loaders_module = importlib.import_module('document_loaders')
         else:
             document_loaders_module = importlib.import_module('langchain.document_loaders')
@@ -168,10 +171,11 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
         # Auto-detect the file encoding so the langchain loader does not fail with encoding errors
         with open(file_path_or_content, 'rb') as struct_file:
             encode_detect = chardet.detect(struct_file.read())
-        if encode_detect:
-            loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
-        else:
-            loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+        if encode_detect is None:
+            encode_detect = {"encoding": "utf-8"}
+
+        loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
+        ## TODO: support more custom CSV reading logic
 
     elif loader_name == "JSONLoader":
         loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
@@ -187,10 +191,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
 
 
 def make_text_splitter(
-    splitter_name: str = TEXT_SPLITTER_NAME,
-    chunk_size: int = CHUNK_SIZE,
-    chunk_overlap: int = OVERLAP_SIZE,
-    llm_model: str = LLM_MODEL,
+        splitter_name: str = TEXT_SPLITTER_NAME,
+        chunk_size: int = CHUNK_SIZE,
+        chunk_overlap: int = OVERLAP_SIZE,
+        llm_model: str = LLM_MODEL,
 ):
     """
     Get a specific text splitter according to the parameters
@@ -262,6 +266,7 @@ def make_text_splitter(
         text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
     return text_splitter
 
+
 class KnowledgeFile:
     def __init__(
             self,
@@ -282,7 +287,7 @@ class KnowledgeFile:
         self.document_loader_name = get_LoaderClass(self.ext)
         self.text_splitter_name = TEXT_SPLITTER_NAME
 
-    def file2docs(self, refresh: bool=False):
+    def file2docs(self, refresh: bool = False):
         if self.docs is None or refresh:
             logger.info(f"{self.document_loader_name} used for {self.filepath}")
             loader = get_loader(self.document_loader_name, self.filepath)
@@ -290,20 +295,21 @@ class KnowledgeFile:
         return self.docs
 
     def docs2texts(
-        self,
-        docs: List[Document] = None,
-        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        refresh: bool = False,
-        chunk_size: int = CHUNK_SIZE,
-        chunk_overlap: int = OVERLAP_SIZE,
-        text_splitter: TextSplitter = None,
+            self,
+            docs: List[Document] = None,
+            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
     ):
         docs = docs or self.file2docs(refresh=refresh)
         if not docs:
             return []
         if self.ext not in [".csv"]:
             if text_splitter is None:
-                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap)
             if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                 docs = text_splitter.split_text(docs[0].page_content)
                 for doc in docs:
@@ -320,12 +326,12 @@ class KnowledgeFile:
         return self.splited_docs
 
     def file2text(
-        self,
-        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        refresh: bool = False,
-        chunk_size: int = CHUNK_SIZE,
-        chunk_overlap: int = OVERLAP_SIZE,
-        text_splitter: TextSplitter = None,
+            self,
+            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
     ):
         if self.splited_docs is None or refresh:
             docs = self.file2docs()
@@ -359,6 +365,7 @@ def files2docs_in_thread(
     If the input argument is a Tuple, it has the form (filename, kb_name)
     The generator yields status, (kb_name, file_name, docs | error)
     '''
+
    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
@@ -373,8 +380,8 @@ def files2docs_in_thread(
         kwargs = {}
         try:
             if isinstance(file, tuple) and len(file) >= 2:
-                filename=file[0]
-                kb_name=file[1]
+                filename = file[0]
+                kb_name = file[1]
                 file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             elif isinstance(file, dict):
                 filename = file.pop("filename")
@@ -396,10 +403,9 @@ def files2docs_in_thread(
 
 if __name__ == "__main__":
     from pprint import pprint
-    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+
+    kb_file = KnowledgeFile(
+        filename="/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv",
+        knowledge_base_name="samples")
     # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
     docs = kb_file.file2docs()
-    pprint(docs[-1])
-
-    docs = kb_file.file2text()
-    pprint(docs[-1])
+    # pprint(docs[-1])
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index f3fb4f6..5ff5a78 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -4,16 +4,17 @@ from streamlit_chatbox import *
 from datetime import datetime
 import os
 from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
 from typing import List, Dict
 
-
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
         "chatchat_icon_blue_square_v2.png"
     )
 )
+
+
 def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
     '''
     Return the message history.
@@ -55,14 +56,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
         if is_lite:
             dialogue_modes = ["LLM 对话",
-                            "搜索引擎问答",
-                            ]
+                              "搜索引擎问答",
+                              ]
         else:
             dialogue_modes = ["LLM 对话",
-                            "知识库问答",
-                            "搜索引擎问答",
-                            "自定义Agent问答",
-                            ]
+                              "知识库问答",
+                              "搜索引擎问答",
+                              "自定义Agent问答",
+                              ]
         dialogue_mode = st.selectbox("请选择对话模式:",
                                      dialogue_modes,
                                      index=0,
@@ -102,10 +103,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             key="llm_model",
         )
         if (st.session_state.get("prev_llm_model") != llm_model
-            and not is_lite
-            and not llm_model in config_models.get("online", {})
-            and not llm_model in config_models.get("langchain", {})
-            and llm_model not in running_models):
+                and not is_lite
+                and not llm_model in config_models.get("online", {})
+                and not llm_model in config_models.get("langchain", {})
+                and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
                 r = api.change_llm_model(prev_model, llm_model)
@@ -210,17 +211,20 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
         elif dialogue_mode == "自定义Agent问答":
-            chat_box.ai_say([
-                f"正在思考...",
-                Markdown("...", in_expander=True, title="思考过程", state="complete"),
+            if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+                chat_box.ai_say([
+                    f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
 
-            ])
+                ])
+            else:
+                chat_box.ai_say([
+                    f"正在思考...",
+                    Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+                ])
             text = ""
             ans = ""
-            support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"]  # Models that currently support Agent use
-            if not any(agent in llm_model for agent in support_agent):
-                ans += "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n"
-                chat_box.update_msg(ans, element_index=0, streaming=False)
             for d in api.agent_chat(prompt,
                                     history=history,
                                     model=llm_model,
@@ -278,7 +282,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                           model=llm_model,
                                           prompt_name=prompt_template_name,
                                           temperature=temperature,
-                                          split_result=se_top_k>1):
+                                          split_result=se_top_k > 1):
                 if error_msg := check_error_msg(d):  # check whether an error occurred
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
@@ -305,4 +309,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
         mime="text/markdown",
         use_container_width=True,
-    )
\ No newline at end of file
+    )
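Usage sketch for the new FilteredCSVLoader: its LOADER_DICT entry is still commented out above, so for now the loader has to be constructed directly. The snippet below is a minimal illustration rather than part of the patch; the file name "data.csv" and the column names "content" and "author" are hypothetical placeholders.

    from document_loaders.FilteredCSVloader import FilteredCSVLoader

    # Read only the "content" column as page_content; keep "author" as metadata.
    # "data.csv" and both column names are illustrative, not from the patch.
    loader = FilteredCSVLoader(
        file_path="data.csv",
        columns_to_read=["content"],
        metadata_columns=["author"],
        autodetect_encoding=True,
    )
    docs = loader.load()
    for doc in docs:
        print(doc.metadata["row"], doc.metadata.get("author"), doc.page_content[:50])

Note that as written, __read_file only consumes the first entry of columns_to_read; additional column names are accepted but ignored, and a ValueError is raised if that first column is missing from the CSV header.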